Unverified commit bec57bcf, authored by zhiboniu, committed by GitHub

[cherry-pick] cherry-pick of petr and tinypose3d_human3.6m (#7816)

* keypoint petr (#7774)

* petr train ok

train ok

refix augsize

affine size fix

update msdeformable

fix flip/affine

fix clip

add resize area

add distortion

debug mode

fix pos_inds

update edge joints

fix wording mistakes

* delete extra code; adapt to transformer modifications; update code format

* revert old transformer modifications

* integrate datasets

* add config and architecture for human36m (#7802)

* add config and architecture for human36m

* modify TinyPose3DHRNet to support human3.6M dataset

* delete useless class

---------
Co-authored-by: XYZ <1290573099@qq.com>
Parent 93e2d433
......@@ -18,9 +18,9 @@ __pycache__/
# Distribution / packaging
/bin/
/build/
*build/
/develop-eggs/
/dist/
*dist/
/eggs/
/lib/
/lib64/
......@@ -30,7 +30,7 @@ __pycache__/
/parts/
/sdist/
/var/
/*.egg-info/
*.egg-info/
/.installed.cfg
/*.egg
/.eggs
......
......@@ -56,8 +56,10 @@ The keypoint detection part of PaddleDetection keeps up with state-of-the-art algorithms, including Top
## Model Zoo
COCO Dataset
| Model | Scheme | Input Size | AP(coco val) | Model Download | Config File |
| :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------| ------- |
| PETR_Res50 |One-Stage| 512 | 65.5 | [petr_res50.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/petr_resnet50_16x2_coco.pdparams) | [config](./petr/petr_resnet50_16x2_coco.yml) |
| HigherHRNet-w32 |Bottom-Up| 512 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) |
| HigherHRNet-w32 | Bottom-Up| 640 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) |
| HigherHRNet-w32+SWAHR |Bottom-Up| 512 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) |
......
......@@ -62,6 +62,7 @@ At the same time, PaddleDetection provides a self-developed real-time keypoint d
COCO Dataset
| Model | Scheme | Input Size | AP(coco val) | Model Download | Config File |
| :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------| ------- |
| PETR_Res50 |One-Stage| 512 | 65.5 | [petr_res50.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/petr_resnet50_16x2_coco.pdparams) | [config](./petr/petr_resnet50_16x2_coco.yml) |
| HigherHRNet-w32 | Bottom-Up | 512 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) |
| HigherHRNet-w32 | Bottom-Up | 640 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) |
| HigherHRNet-w32+SWAHR | Bottom-Up | 512 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) |
......
use_gpu: true
log_iter: 50
save_dir: output
snapshot_epoch: 1
weights: output/petr_resnet50_16x2_coco/model_final
epoch: 100
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: COCO
num_classes: 1
trainsize: &trainsize 512
flip_perm: &flip_perm [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
find_unused_parameters: False
#####model
architecture: PETR
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/PETR_pretrained.pdparams
PETR:
backbone:
name: ResNet
depth: 50
variant: b
norm_type: bn
freeze_norm: True
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
lr_mult_list: [0.1, 0.1, 0.1, 0.1]
neck:
name: ChannelMapper
in_channels: [512, 1024, 2048]
kernel_size: 1
out_channels: 256
norm_type: "gn"
norm_groups: 32
act: None
num_outs: 4
bbox_head:
name: PETRHead
num_query: 300
num_classes: 1 # only person
in_channels: 2048
sync_cls_avg_factor: true
with_kpt_refine: true
transformer:
name: PETRTransformer
as_two_stage: true
encoder:
name: TransformerEncoder
encoder_layer:
name: TransformerEncoderLayer
d_model: 256
attn:
name: MSDeformableAttention
embed_dim: 256
num_heads: 8
num_levels: 4
num_points: 4
dim_feedforward: 1024
dropout: 0.1
num_layers: 6
decoder:
name: PETR_TransformerDecoder
num_layers: 3
return_intermediate: true
decoder_layer:
name: PETR_TransformerDecoderLayer
d_model: 256
dim_feedforward: 1024
dropout: 0.1
self_attn:
name: MultiHeadAttention
embed_dim: 256
num_heads: 8
dropout: 0.1
cross_attn:
name: MultiScaleDeformablePoseAttention
embed_dims: 256
num_heads: 8
num_levels: 4
num_points: 17
hm_encoder:
name: TransformerEncoder
encoder_layer:
name: TransformerEncoderLayer
d_model: 256
attn:
name: MSDeformableAttention
embed_dim: 256
num_heads: 8
num_levels: 1
num_points: 4
dim_feedforward: 1024
dropout: 0.1
num_layers: 1
refine_decoder:
name: PETR_DeformableDetrTransformerDecoder
num_layers: 2
return_intermediate: true
decoder_layer:
name: PETR_TransformerDecoderLayer
d_model: 256
dim_feedforward: 1024
dropout: 0.1
self_attn:
name: MultiHeadAttention
embed_dim: 256
num_heads: 8
dropout: 0.1
cross_attn:
name: MSDeformableAttention
embed_dim: 256
num_levels: 4
positional_encoding:
name: PositionEmbedding
num_pos_feats: 128
normalize: true
offset: -0.5
loss_cls:
name: Weighted_FocalLoss
use_sigmoid: true
gamma: 2.0
alpha: 0.25
loss_weight: 2.0
reduction: "mean"
loss_kpt:
name: L1Loss
loss_weight: 70.0
loss_kpt_rpn:
name: L1Loss
loss_weight: 70.0
loss_oks:
name: OKSLoss
loss_weight: 2.0
loss_hm:
name: CenterFocalLoss
loss_weight: 4.0
loss_kpt_refine:
name: L1Loss
loss_weight: 80.0
loss_oks_refine:
name: OKSLoss
loss_weight: 3.0
assigner:
name: PoseHungarianAssigner
cls_cost:
name: FocalLossCost
weight: 2.0
kpt_cost:
name: KptL1Cost
weight: 70.0
oks_cost:
name: OksCost
weight: 7.0
#####optimizer
LearningRate:
base_lr: 0.0002
schedulers:
- !PiecewiseDecay
milestones: [80]
gamma: 0.1
use_warmup: false
# - !LinearWarmup
# start_factor: 0.001
# steps: 1000
OptimizerBuilder:
clip_grad_by_norm: 0.1
optimizer:
type: AdamW
regularizer:
factor: 0.0001
type: L2
#####data
TrainDataset:
!KeypointBottomUpCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
return_mask: false
EvalDataset:
!KeypointBottomUpCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
test_mode: true
return_mask: false
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- Decode: {}
- PhotoMetricDistortion:
brightness_delta: 32
contrast_range: [0.5, 1.5]
saturation_range: [0.5, 1.5]
hue_delta: 18
- KeyPointFlip:
flip_prob: 0.5
flip_permutation: *flip_perm
- RandomAffine:
max_degree: 30
scale: [1.0, 1.0]
max_shift: 0.
trainsize: -1
- RandomSelect: { transforms1: [ RandomShortSideRangeResize: { scales: [[400, 1400], [1400, 1400]]} ],
transforms2: [
RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] },
RandomSizeCrop: { min_size: 384, max_size: 600},
RandomShortSideRangeResize: { scales: [[400, 1400], [1400, 1400]]} ]}
batch_transforms:
- NormalizeImage: {mean: *global_mean, std: *global_std, is_scale: True}
- PadGT: {pad_img: True, minimum_gtnum: 1}
- Permute: {}
batch_size: 2
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true
EvalReader:
sample_transforms:
- PETR_Resize: {img_scale: [[800, 1333]], keep_ratio: True}
# - MultiscaleTestResize: {origin_target_size: [[800, 1333]], use_flip: false}
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
TestReader:
sample_transforms:
- Decode: {}
- EvalAffine:
size: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
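For readers skimming the optimizer section of this config: a minimal plain-Python sketch (illustrative only, not PaddleDetection's `PiecewiseDecay` class) of the learning rates the schedule above yields over the 100 epochs.

```python
# Reproduces the piecewise-decay values from the config above
# (base_lr 0.0002, milestone at epoch 80, gamma 0.1, no warmup).
base_lr, milestones, gamma = 0.0002, [80], 0.1

def lr_at_epoch(epoch):
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr

print(lr_at_epoch(0), lr_at_epoch(79), lr_at_epoch(80))  # 0.0002 0.0002 2e-05
```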
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 1
weights: output/tinypose3d_human36M/model_final
epoch: 220
num_joints: &num_joints 24
pixel_std: &pixel_std 200
metric: Pose3DEval
num_classes: 1
train_height: &train_height 128
train_width: &train_width 128
trainsize: &trainsize [*train_width, *train_height]
#####model
architecture: TinyPose3DHRNet
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.pdparams
TinyPose3DHRNet:
backbone: LiteHRNet
post_process: HR3DNetPostProcess
fc_channel: 1024
num_joints: *num_joints
width: &width 40
loss: Pose3DLoss
LiteHRNet:
network_type: wider_naive
freeze_at: -1
freeze_norm: false
return_idx: [0]
Pose3DLoss:
weight_3d: 1.0
weight_2d: 0.0
#####optimizer
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
milestones: [17, 21]
gamma: 0.1
- !LinearWarmup
start_factor: 0.01
steps: 1000
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!Pose3DDataset
dataset_dir: Human3.6M
image_dirs: ["Images"]
anno_list: ['Human3.6m_train.json']
num_joints: *num_joints
test_mode: False
EvalDataset:
!Pose3DDataset
dataset_dir: Human3.6M
image_dirs: ["Images"]
anno_list: ['Human3.6m_valid.json']
num_joints: *num_joints
test_mode: True
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- SinglePoseAffine:
trainsize: *trainsize
rotate: [0.5, 30] #[prob, rotate range]
scale: [0.5, 0.25] #[prob, scale range]
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 128
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- SinglePoseAffine:
trainsize: *trainsize
rotate: [0., 30]
scale: [0., 0.25]
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 128
TestReader:
inputs_def:
image_shape: [3, *train_height, *train_width]
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
fuse_normalize: false
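The `trainsize` entry in this config is assembled from YAML anchors. A small sketch with PyYAML (used here purely for illustration; PaddleDetection reads configs with its own loader) shows how the anchors resolve:

```python
# Resolving the width/height anchors exactly as written in the config above.
import yaml

snippet = """
train_height: &train_height 128
train_width: &train_width 128
trainsize: &trainsize [*train_width, *train_height]
"""
cfg = yaml.safe_load(snippet)
print(cfg["trainsize"])  # [128, 128] -> the (width, height) passed to SinglePoseAffine
```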
......@@ -82,7 +82,7 @@ MPII keypoint indexes:
```
{
'joints_vis': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
'joints': [
'gt_joints': [
[-1.0, -1.0],
[-1.0, -1.0],
[-1.0, -1.0],
......
......@@ -82,7 +82,7 @@ The following example takes a parsed annotation information to illustrate the co
```
{
'joints_vis': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
'joints': [
'gt_joints': [
[-1.0, -1.0],
[-1.0, -1.0],
[-1.0, -1.0],
......
......@@ -80,6 +80,7 @@ class KeypointBottomUpBaseDataset(DetDataset):
records = copy.deepcopy(self._get_imganno(idx))
records['image'] = cv2.imread(records['image_file'])
records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
if 'mask' in records:
records['mask'] = (records['mask'] + 0).astype('uint8')
records = self.transform(records)
return records
......@@ -135,24 +136,37 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
num_joints,
transform=[],
shard=[0, 1],
test_mode=False):
test_mode=False,
return_mask=True,
return_bbox=True,
return_area=True,
return_class=True):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform, shard, test_mode)
self.ann_file = os.path.join(dataset_dir, anno_path)
self.shard = shard
self.test_mode = test_mode
self.return_mask = return_mask
self.return_bbox = return_bbox
self.return_area = return_area
self.return_class = return_class
def parse_dataset(self):
self.coco = COCO(self.ann_file)
self.img_ids = self.coco.getImgIds()
if not self.test_mode:
self.img_ids = [
img_id for img_id in self.img_ids
if len(self.coco.getAnnIds(
imgIds=img_id, iscrowd=None)) > 0
]
self.img_ids_tmp = []
for img_id in self.img_ids:
ann_ids = self.coco.getAnnIds(imgIds=img_id)
anno = self.coco.loadAnns(ann_ids)
anno = [obj for obj in anno if obj['iscrowd'] == 0]
if len(anno) == 0:
continue
self.img_ids_tmp.append(img_id)
self.img_ids = self.img_ids_tmp
blocknum = int(len(self.img_ids) / self.shard[1])
self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
self.shard[0] + 1))]
......@@ -199,21 +213,31 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
ann_ids = coco.getAnnIds(imgIds=img_id)
anno = coco.loadAnns(ann_ids)
mask = self._get_mask(anno, idx)
anno = [
obj for obj in anno
if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0
if obj['iscrowd'] == 0 and obj['num_keypoints'] > 0
]
db_rec = {}
joints, orgsize = self._get_joints(anno, idx)
db_rec['gt_joints'] = joints
db_rec['im_shape'] = orgsize
if self.return_bbox:
db_rec['gt_bbox'] = self._get_bboxs(anno, idx)
if self.return_class:
db_rec['gt_class'] = self._get_labels(anno, idx)
if self.return_area:
db_rec['gt_areas'] = self._get_areas(anno, idx)
if self.return_mask:
db_rec['mask'] = self._get_mask(anno, idx)
db_rec = {}
db_rec['im_id'] = img_id
db_rec['image_file'] = os.path.join(self.img_prefix,
self.id2name[img_id])
db_rec['mask'] = mask
db_rec['joints'] = joints
db_rec['im_shape'] = orgsize
return db_rec
......@@ -229,12 +253,41 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
np.array(obj['keypoints']).reshape([-1, 3])
img_info = self.coco.loadImgs(self.img_ids[idx])[0]
joints[..., 0] /= img_info['width']
joints[..., 1] /= img_info['height']
orgsize = np.array([img_info['height'], img_info['width']])
orgsize = np.array([img_info['height'], img_info['width'], 1])
return joints, orgsize
def _get_bboxs(self, anno, idx):
num_people = len(anno)
gt_bboxes = np.zeros((num_people, 4), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'bbox' in obj:
gt_bboxes[idx, :] = obj['bbox']
gt_bboxes[:, 2] += gt_bboxes[:, 0]
gt_bboxes[:, 3] += gt_bboxes[:, 1]
return gt_bboxes
def _get_labels(self, anno, idx):
num_people = len(anno)
gt_labels = np.zeros((num_people, 1), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'category_id' in obj:
catid = obj['category_id']
gt_labels[idx, 0] = self.catid2clsid[catid]
return gt_labels
def _get_areas(self, anno, idx):
num_people = len(anno)
gt_areas = np.zeros((num_people, ), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'area' in obj:
gt_areas[idx, ] = obj['area']
return gt_areas
def _get_mask(self, anno, idx):
"""Get ignore masks to mask out losses."""
coco = self.coco
......@@ -506,7 +559,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
'image_file': os.path.join(self.img_prefix, file_name),
'center': center,
'scale': scale,
'joints': joints,
'gt_joints': joints,
'joints_vis': joints_vis,
'im_id': im_id,
})
......@@ -570,7 +623,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
'center': center,
'scale': scale,
'score': score,
'joints': joints,
'gt_joints': joints,
'joints_vis': joints_vis,
})
......@@ -647,8 +700,8 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
(self.ann_info['num_joints'], 3), dtype=np.float32)
joints_vis = np.zeros(
(self.ann_info['num_joints'], 3), dtype=np.float32)
if 'joints' in a:
joints_ = np.array(a['joints'])
if 'gt_joints' in a:
joints_ = np.array(a['gt_joints'])
joints_[:, 0:2] = joints_[:, 0:2] - 1
joints_vis_ = np.array(a['joints_vis'])
assert len(joints_) == self.ann_info[
......@@ -664,7 +717,7 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
'im_id': im_id,
'center': c,
'scale': s,
'joints': joints,
'gt_joints': joints,
'joints_vis': joints_vis
})
print("number length: {}".format(len(gt_db)))
......
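For reference, the `_get_bboxs` helper introduced above turns COCO-style `[x, y, w, h]` annotations into the `[x1, y1, x2, y2]` boxes stored under `gt_bbox`. A standalone sketch of that conversion:

```python
# Mirrors the _get_bboxs logic: add x/y offsets to width/height to get corners.
import numpy as np

anno = [{"bbox": [10.0, 20.0, 30.0, 40.0]}]         # one person, COCO xywh
gt_bboxes = np.zeros((len(anno), 4), dtype=np.float32)
for i, obj in enumerate(anno):
    gt_bboxes[i, :] = obj["bbox"]
gt_bboxes[:, 2] += gt_bboxes[:, 0]                   # x2 = x + w
gt_bboxes[:, 3] += gt_bboxes[:, 1]                   # y2 = y + h
print(gt_bboxes)                                     # [[10. 20. 40. 60.]]
```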
......@@ -1102,13 +1102,115 @@ class PadGT(BaseOperator):
1 means bbox, 0 means no bbox.
"""
def __init__(self, return_gt_mask=True):
def __init__(self, return_gt_mask=True, pad_img=False, minimum_gtnum=0):
super(PadGT, self).__init__()
self.return_gt_mask = return_gt_mask
self.pad_img = pad_img
self.minimum_gtnum = minimum_gtnum
def _impad(self, img: np.ndarray,
*,
shape = None,
padding = None,
pad_val = 0,
padding_mode = 'constant') -> np.ndarray:
"""Pad the given image to a certain shape or pad on all sides with
specified padding mode and padding value.
Args:
img (ndarray): Image to be padded.
shape (tuple[int]): Expected padding shape (h, w). Default: None.
padding (int or tuple[int]): Padding on each border. If a single int is
provided this is used to pad all borders. If tuple of length 2 is
provided this is the padding on left/right and top/bottom
respectively. If a tuple of length 4 is provided this is the
padding for the left, top, right and bottom borders respectively.
Default: None. Note that `shape` and `padding` can not be both
set.
pad_val (Number | Sequence[Number]): Values to be filled in padding
areas when padding_mode is 'constant'. Default: 0.
padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Default: constant.
- constant: pads with a constant value, this value is specified
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with 2
elements on both sides in reflect mode will result in
[3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last value
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
Returns:
ndarray: The padded image.
"""
assert (shape is not None) ^ (padding is not None)
if shape is not None:
width = max(shape[1] - img.shape[1], 0)
height = max(shape[0] - img.shape[0], 0)
padding = (0, 0, int(width), int(height))
# check pad_val
import numbers
if isinstance(pad_val, tuple):
assert len(pad_val) == img.shape[-1]
elif not isinstance(pad_val, numbers.Number):
raise TypeError('pad_val must be an int or a tuple. '
f'But received {type(pad_val)}')
# check padding
if isinstance(padding, tuple) and len(padding) in [2, 4]:
if len(padding) == 2:
padding = (padding[0], padding[1], padding[0], padding[1])
elif isinstance(padding, numbers.Number):
padding = (padding, padding, padding, padding)
else:
raise ValueError('Padding must be an int or a 2- or 4-element tuple. '
f'But received {padding}')
# check padding mode
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
border_type = {
'constant': cv2.BORDER_CONSTANT,
'edge': cv2.BORDER_REPLICATE,
'reflect': cv2.BORDER_REFLECT_101,
'symmetric': cv2.BORDER_REFLECT
}
img = cv2.copyMakeBorder(
img,
padding[1],
padding[3],
padding[0],
padding[2],
border_type[padding_mode],
value=pad_val)
return img
def checkmaxshape(self, samples):
maxh, maxw = 0, 0
for sample in samples:
h, w = sample['im_shape']
if h > maxh:
maxh = h
if w > maxw:
maxw = w
return (maxh, maxw)
def __call__(self, samples, context=None):
num_max_boxes = max([len(s['gt_bbox']) for s in samples])
num_max_boxes = max(self.minimum_gtnum, num_max_boxes)
if self.pad_img:
maxshape = self.checkmaxshape(samples)
for sample in samples:
if self.pad_img:
img = sample['image']
padimg = self._impad(img, shape=maxshape)
sample['image'] = padimg
if self.return_gt_mask:
sample['pad_gt_mask'] = np.zeros(
(num_max_boxes, 1), dtype=np.float32)
......@@ -1142,6 +1244,17 @@ class PadGT(BaseOperator):
if num_gt > 0:
pad_diff[:num_gt] = sample['difficult']
sample['difficult'] = pad_diff
if 'gt_joints' in sample:
num_joints = sample['gt_joints'].shape[1]
pad_gt_joints = np.zeros((num_max_boxes, num_joints, 3), dtype=np.float32)
if num_gt > 0:
pad_gt_joints[:num_gt] = sample['gt_joints']
sample['gt_joints'] = pad_gt_joints
if 'gt_areas' in sample:
pad_gt_areas = np.zeros((num_max_boxes, 1), dtype=np.float32)
if num_gt > 0:
pad_gt_areas[:num_gt, 0] = sample['gt_areas']
sample['gt_areas'] = pad_gt_areas
return samples
......
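A simplified sketch of the padding idea `PadGT` implements above (hypothetical helper name, not the operator itself): every per-image ground-truth array is padded to the largest instance count in the batch, and `pad_gt_mask` marks which rows are real boxes.

```python
import numpy as np

def pad_sample(gt_bbox, num_max_boxes):
    # num_max_boxes would be max(minimum_gtnum, largest gt count in the batch).
    num_gt = len(gt_bbox)
    pad_bbox = np.zeros((num_max_boxes, 4), dtype=np.float32)
    pad_mask = np.zeros((num_max_boxes, 1), dtype=np.float32)
    if num_gt > 0:
        pad_bbox[:num_gt] = gt_bbox
        pad_mask[:num_gt] = 1.0
    return pad_bbox, pad_mask

bbox, mask = pad_sample(np.ones((2, 4), dtype=np.float32), num_max_boxes=4)
print(mask.ravel())  # [1. 1. 0. 0.] -> two real boxes, two padded rows
```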
......@@ -41,7 +41,7 @@ __all__ = [
'TopDownAffine', 'ToHeatmapsTopDown', 'ToHeatmapsTopDown_DARK',
'ToHeatmapsTopDown_UDP', 'TopDownEvalAffine',
'AugmentationbyInformantionDropping', 'SinglePoseAffine', 'NoiseJitter',
'FlipPose'
'FlipPose', 'PETR_Resize'
]
......@@ -65,38 +65,77 @@ class KeyPointFlip(object):
"""
def __init__(self, flip_permutation, hmsize, flip_prob=0.5):
def __init__(self, flip_permutation, hmsize=None, flip_prob=0.5):
super(KeyPointFlip, self).__init__()
assert isinstance(flip_permutation, Sequence)
self.flip_permutation = flip_permutation
self.flip_prob = flip_prob
self.hmsize = hmsize
def __call__(self, records):
image = records['image']
kpts_lst = records['joints']
mask_lst = records['mask']
flip = np.random.random() < self.flip_prob
if flip:
image = image[:, ::-1]
for idx, hmsize in enumerate(self.hmsize):
if len(mask_lst) > idx:
mask_lst[idx] = mask_lst[idx][:, ::-1]
def _flipjoints(self, records, sizelst):
'''
records['gt_joints'] is Sequence in higherhrnet
'''
if not ('gt_joints' in records and records['gt_joints'].size > 0):
return records
kpts_lst = records['gt_joints']
if isinstance(kpts_lst, Sequence):
for idx, hmsize in enumerate(sizelst):
if kpts_lst[idx].ndim == 3:
kpts_lst[idx] = kpts_lst[idx][:, self.flip_permutation]
else:
kpts_lst[idx] = kpts_lst[idx][self.flip_permutation]
kpts_lst[idx][..., 0] = hmsize - kpts_lst[idx][..., 0]
kpts_lst[idx] = kpts_lst[idx].astype(np.int64)
kpts_lst[idx][kpts_lst[idx][..., 0] >= hmsize, 2] = 0
kpts_lst[idx][kpts_lst[idx][..., 1] >= hmsize, 2] = 0
kpts_lst[idx][kpts_lst[idx][..., 0] < 0, 2] = 0
kpts_lst[idx][kpts_lst[idx][..., 1] < 0, 2] = 0
records['image'] = image
records['joints'] = kpts_lst
else:
hmsize = sizelst[0]
if kpts_lst.ndim == 3:
kpts_lst = kpts_lst[:, self.flip_permutation]
else:
kpts_lst = kpts_lst[self.flip_permutation]
kpts_lst[..., 0] = hmsize - kpts_lst[..., 0]
records['gt_joints'] = kpts_lst
return records
def _flipmask(self, records, sizelst):
if 'mask' not in records:
return records
mask_lst = records['mask']
for idx, hmsize in enumerate(sizelst):
if len(mask_lst) > idx:
mask_lst[idx] = mask_lst[idx][:, ::-1]
records['mask'] = mask_lst
return records
def _flipbbox(self, records, sizelst):
if 'gt_bbox' not in records:
return records
bboxes = records['gt_bbox']
hmsize = sizelst[0]
bboxes[:, 0::2] = hmsize - bboxes[:, 0::2][:, ::-1]
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, hmsize)
records['gt_bbox'] = bboxes
return records
def __call__(self, records):
flip = np.random.random() < self.flip_prob
if flip:
image = records['image']
image = image[:, ::-1]
records['image'] = image
if self.hmsize is None:
sizelst = [image.shape[1]]
else:
sizelst = self.hmsize
self._flipjoints(records, sizelst)
self._flipmask(records, sizelst)
self._flipbbox(records, sizelst)
return records
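As a sanity check for the rewritten `KeyPointFlip`, a small sketch (single person, COCO joint order, made-up values) of the non-Sequence branch when `hmsize` is None: x is mirrored around the image width and left/right joints are swapped with the permutation from the config.

```python
import numpy as np

flip_perm = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
width = 512
joints = np.zeros((17, 3), dtype=np.float32)
joints[1] = [100.0, 50.0, 2.0]             # left eye, visible

flipped = joints[flip_perm]                # swap left/right joints
flipped[..., 0] = width - flipped[..., 0]  # mirror x around the image width
print(flipped[2])                          # [412. 50. 2.] -> now in the right-eye slot
```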
@register_keypointop
class RandomAffine(object):
......@@ -121,9 +160,10 @@ class RandomAffine(object):
max_degree=30,
scale=[0.75, 1.5],
max_shift=0.2,
hmsize=[128, 256],
hmsize=None,
trainsize=512,
scale_type='short'):
scale_type='short',
boldervalue=[114, 114, 114]):
super(RandomAffine, self).__init__()
self.max_degree = max_degree
self.min_scale = scale[0]
......@@ -132,8 +172,9 @@ class RandomAffine(object):
self.hmsize = hmsize
self.trainsize = trainsize
self.scale_type = scale_type
self.boldervalue = boldervalue
def _get_affine_matrix(self, center, scale, res, rot=0):
def _get_affine_matrix_old(self, center, scale, res, rot=0):
"""Generate transformation matrix."""
h = scale
t = np.zeros((3, 3), dtype=np.float32)
......@@ -159,21 +200,94 @@ class RandomAffine(object):
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
return t
def _get_affine_matrix(self, center, scale, res, rot=0):
"""Generate transformation matrix."""
w, h = scale
t = np.zeros((3, 3), dtype=np.float32)
t[0, 0] = float(res[0]) / w
t[1, 1] = float(res[1]) / h
t[0, 2] = res[0] * (-float(center[0]) / w + .5)
t[1, 2] = res[1] * (-float(center[1]) / h + .5)
t[2, 2] = 1
if rot != 0:
rot = -rot # To match direction of rotation from cropping
rot_mat = np.zeros((3, 3), dtype=np.float32)
rot_rad = rot * np.pi / 180
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[0, :2] = [cs, -sn]
rot_mat[1, :2] = [sn, cs]
rot_mat[2, 2] = 1
# Need to rotate around center
t_mat = np.eye(3)
t_mat[0, 2] = -res[0] / 2
t_mat[1, 2] = -res[1] / 2
t_inv = t_mat.copy()
t_inv[:2, 2] *= -1
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
return t
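A quick numerical check of the new `_get_affine_matrix` (re-implemented standalone here, rotation omitted): the crop center should land exactly at the center of the output resolution.

```python
import numpy as np

def affine_matrix(center, scale, res):
    w, h = scale                        # crop size in source pixels (w, h)
    t = np.eye(3, dtype=np.float32)
    t[0, 0] = res[0] / w
    t[1, 1] = res[1] / h
    t[0, 2] = res[0] * (-center[0] / w + 0.5)
    t[1, 2] = res[1] * (-center[1] / h + 0.5)
    return t

center = np.array([320.0, 240.0])        # crop center in the source image
t = affine_matrix(center, scale=(480.0, 480.0), res=(512, 512))
print(t @ np.array([center[0], center[1], 1.0]))  # [256. 256. 1.] -> output center
```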
def _affine_joints_mask(self,
degree,
center,
roi_size,
dsize,
keypoints=None,
heatmap_mask=None,
gt_bbox=None):
kpts = None
mask = None
bbox = None
mask_affine_mat = self._get_affine_matrix(center, roi_size, dsize,
degree)[:2]
if heatmap_mask is not None:
mask = cv2.warpAffine(heatmap_mask, mask_affine_mat, dsize)
mask = ((mask / 255) > 0.5).astype(np.float32)
if keypoints is not None:
kpts = copy.deepcopy(keypoints)
kpts[..., 0:2] = warp_affine_joints(kpts[..., 0:2].copy(),
mask_affine_mat)
kpts[(kpts[..., 0]) > dsize[0], :] = 0
kpts[(kpts[..., 1]) > dsize[1], :] = 0
kpts[(kpts[..., 0]) < 0, :] = 0
kpts[(kpts[..., 1]) < 0, :] = 0
if gt_bbox is not None:
temp_bbox = gt_bbox[:, [0, 3, 2, 1]]
cat_bbox = np.concatenate((gt_bbox, temp_bbox), axis=-1)
gt_bbox_warped = warp_affine_joints(cat_bbox, mask_affine_mat)
bbox = np.zeros_like(gt_bbox)
bbox[:, 0] = gt_bbox_warped[:, 0::2].min(1).clip(0, dsize[0])
bbox[:, 2] = gt_bbox_warped[:, 0::2].max(1).clip(0, dsize[0])
bbox[:, 1] = gt_bbox_warped[:, 1::2].min(1).clip(0, dsize[1])
bbox[:, 3] = gt_bbox_warped[:, 1::2].max(1).clip(0, dsize[1])
return kpts, mask, bbox
def __call__(self, records):
image = records['image']
keypoints = records['joints']
shape = np.array(image.shape[:2][::-1])
keypoints = None
heatmap_mask = None
gt_bbox = None
if 'gt_joints' in records:
keypoints = records['gt_joints']
if 'mask' in records:
heatmap_mask = records['mask']
heatmap_mask *= 255
if 'gt_bbox' in records:
gt_bbox = records['gt_bbox']
degree = (np.random.random() * 2 - 1) * self.max_degree
shape = np.array(image.shape[:2][::-1])
center = np.array(shape) / 2
aug_scale = np.random.random() * (self.max_scale - self.min_scale
) + self.min_scale
if self.scale_type == 'long':
scale = max(shape[0], shape[1]) / 1.0
scale = np.array([max(shape[0], shape[1]) / 1.0] * 2)
elif self.scale_type == 'short':
scale = min(shape[0], shape[1]) / 1.0
scale = np.array([min(shape[0], shape[1]) / 1.0] * 2)
elif self.scale_type == 'wh':
scale = shape
else:
raise ValueError('Unknown scale type: {}'.format(self.scale_type))
roi_size = aug_scale * scale
......@@ -181,44 +295,55 @@ class RandomAffine(object):
dy = int(0)
if self.max_shift > 0:
dx = np.random.randint(-self.max_shift * roi_size,
self.max_shift * roi_size)
dy = np.random.randint(-self.max_shift * roi_size,
self.max_shift * roi_size)
dx = np.random.randint(-self.max_shift * roi_size[0],
self.max_shift * roi_size[0])
dy = np.random.randint(-self.max_shift * roi_size[1],
self.max_shift * roi_size[1])
center += np.array([dx, dy])
input_size = 2 * center
if self.trainsize != -1:
dsize = self.trainsize
imgshape = (dsize, dsize)
else:
dsize = scale
imgshape = (shape.tolist())
keypoints[..., :2] *= shape
heatmap_mask *= 255
kpts_lst = []
mask_lst = []
image_affine_mat = self._get_affine_matrix(
center, roi_size, (self.trainsize, self.trainsize), degree)[:2]
image_affine_mat = self._get_affine_matrix(center, roi_size, dsize,
degree)[:2]
image = cv2.warpAffine(
image,
image_affine_mat, (self.trainsize, self.trainsize),
flags=cv2.INTER_LINEAR)
image_affine_mat,
imgshape,
flags=cv2.INTER_LINEAR,
borderValue=self.boldervalue)
if self.hmsize is None:
kpts, mask, gt_bbox = self._affine_joints_mask(
degree, center, roi_size, dsize, keypoints, heatmap_mask,
gt_bbox)
records['image'] = image
if kpts is not None: records['gt_joints'] = kpts
if mask is not None: records['mask'] = mask
if gt_bbox is not None: records['gt_bbox'] = gt_bbox
return records
kpts_lst = []
mask_lst = []
for hmsize in self.hmsize:
kpts = copy.deepcopy(keypoints)
mask_affine_mat = self._get_affine_matrix(
center, roi_size, (hmsize, hmsize), degree)[:2]
if heatmap_mask is not None:
mask = cv2.warpAffine(heatmap_mask, mask_affine_mat,
(hmsize, hmsize))
mask = ((mask / 255) > 0.5).astype(np.float32)
kpts[..., 0:2] = warp_affine_joints(kpts[..., 0:2].copy(),
mask_affine_mat)
kpts[np.trunc(kpts[..., 0]) >= hmsize, 2] = 0
kpts[np.trunc(kpts[..., 1]) >= hmsize, 2] = 0
kpts[np.trunc(kpts[..., 0]) < 0, 2] = 0
kpts[np.trunc(kpts[..., 1]) < 0, 2] = 0
kpts, mask, gt_bbox = self._affine_joints_mask(
degree, center, roi_size, [hmsize, hmsize], keypoints,
heatmap_mask, gt_bbox)
kpts_lst.append(kpts)
mask_lst.append(mask)
records['image'] = image
records['joints'] = kpts_lst
if 'gt_joints' in records:
records['gt_joints'] = kpts_lst
if 'mask' in records:
records['mask'] = mask_lst
if 'gt_bbox' in records:
records['gt_bbox'] = gt_bbox
return records
......@@ -251,8 +376,8 @@ class EvalAffine(object):
if mask is not None:
mask = cv2.warpAffine(mask, trans, size_resized)
records['mask'] = mask
if 'joints' in records:
del records['joints']
if 'gt_joints' in records:
del records['gt_joints']
records['image'] = image_resized
return records
......@@ -303,7 +428,7 @@ class TagGenerate(object):
self.num_joints = num_joints
def __call__(self, records):
kpts_lst = records['joints']
kpts_lst = records['gt_joints']
kpts = kpts_lst[0]
tagmap = np.zeros((self.max_people, self.num_joints, 4), dtype=np.int64)
inds = np.where(kpts[..., 2] > 0)
......@@ -315,7 +440,7 @@ class TagGenerate(object):
tagmap[p, j, 2] = visible[..., 0] # x
tagmap[p, j, 3] = 1
records['tagmap'] = tagmap
del records['joints']
del records['gt_joints']
return records
......@@ -349,7 +474,7 @@ class ToHeatmaps(object):
self.gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
def __call__(self, records):
kpts_lst = records['joints']
kpts_lst = records['gt_joints']
mask_lst = records['mask']
for idx, hmsize in enumerate(self.hmsize):
mask = mask_lst[idx]
......@@ -470,7 +595,7 @@ class RandomFlipHalfBodyTransform(object):
def __call__(self, records):
image = records['image']
joints = records['joints']
joints = records['gt_joints']
joints_vis = records['joints_vis']
c = records['center']
s = records['scale']
......@@ -493,7 +618,7 @@ class RandomFlipHalfBodyTransform(object):
joints, joints_vis, image.shape[1], self.flip_pairs)
c[0] = image.shape[1] - c[0] - 1
records['image'] = image
records['joints'] = joints
records['gt_joints'] = joints
records['joints_vis'] = joints_vis
records['center'] = c
records['scale'] = s
......@@ -553,7 +678,7 @@ class AugmentationbyInformantionDropping(object):
def __call__(self, records):
img = records['image']
joints = records['joints']
joints = records['gt_joints']
joints_vis = records['joints_vis']
if np.random.rand() < self.prob_cutout:
img = self._cutout(img, joints, joints_vis)
......@@ -581,7 +706,7 @@ class TopDownAffine(object):
def __call__(self, records):
image = records['image']
joints = records['joints']
joints = records['gt_joints']
joints_vis = records['joints_vis']
rot = records['rotate'] if "rotate" in records else 0
if self.use_udp:
......@@ -606,7 +731,7 @@ class TopDownAffine(object):
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
records['image'] = image
records['joints'] = joints
records['gt_joints'] = joints
return records
......@@ -842,7 +967,7 @@ class ToHeatmapsTopDown(object):
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
"""
joints = records['joints']
joints = records['gt_joints']
joints_vis = records['joints_vis']
num_joints = joints.shape[0]
image_size = np.array(
......@@ -885,7 +1010,7 @@ class ToHeatmapsTopDown(object):
0]:g_y[1], g_x[0]:g_x[1]]
records['target'] = target
records['target_weight'] = target_weight
del records['joints'], records['joints_vis']
del records['gt_joints'], records['joints_vis']
return records
......@@ -910,7 +1035,7 @@ class ToHeatmapsTopDown_DARK(object):
self.sigma = sigma
def __call__(self, records):
joints = records['joints']
joints = records['gt_joints']
joints_vis = records['joints_vis']
num_joints = joints.shape[0]
image_size = np.array(
......@@ -943,7 +1068,7 @@ class ToHeatmapsTopDown_DARK(object):
(x - mu_x)**2 + (y - mu_y)**2) / (2 * self.sigma**2))
records['target'] = target
records['target_weight'] = target_weight
del records['joints'], records['joints_vis']
del records['gt_joints'], records['joints_vis']
return records
......@@ -972,7 +1097,7 @@ class ToHeatmapsTopDown_UDP(object):
self.sigma = sigma
def __call__(self, records):
joints = records['joints']
joints = records['gt_joints']
joints_vis = records['joints_vis']
num_joints = joints.shape[0]
image_size = np.array(
......@@ -1017,6 +1142,472 @@ class ToHeatmapsTopDown_UDP(object):
0]:g_y[1], g_x[0]:g_x[1]]
records['target'] = target
records['target_weight'] = target_weight
del records['joints'], records['joints_vis']
del records['gt_joints'], records['joints_vis']
return records
from typing import Optional, Tuple, Union, List
import numbers
def _scale_size(
size: Tuple[int, int],
scale: Union[float, int, tuple], ) -> Tuple[int, int]:
"""Rescale a size by a ratio.
Args:
size (tuple[int]): (w, h).
scale (float | tuple(float)): Scaling factor.
Returns:
tuple[int]: scaled size.
"""
if isinstance(scale, (float, int)):
scale = (scale, scale)
w, h = size
return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
def rescale_size(old_size: tuple,
scale: Union[float, int, tuple],
return_scale: bool=False) -> tuple:
"""Calculate the new size to be rescaled to.
Args:
old_size (tuple[int]): The old size (w, h) of image.
scale (float | tuple[int]): The scaling factor or maximum size.
If it is a float number, then the image will be rescaled by this
factor, else if it is a tuple of 2 integers, then the image will
be rescaled as large as possible within the scale.
return_scale (bool): Whether to return the scaling factor besides the
rescaled image size.
Returns:
tuple[int]: The new rescaled image size.
"""
w, h = old_size
if isinstance(scale, (float, int)):
if scale <= 0:
raise ValueError(f'Invalid scale {scale}, must be positive.')
scale_factor = scale
elif isinstance(scale, list):
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
else:
raise TypeError(
f'Scale must be a number or tuple of int, but got {type(scale)}')
new_size = _scale_size((w, h), scale_factor)
if return_scale:
return new_size, scale_factor
else:
return new_size
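Worked example for `rescale_size` above: fitting a 640x480 image into a `[1333, 800]` scale budget while keeping the aspect ratio.

```python
# Mirrors rescale_size/_scale_size for old_size=(640, 480), scale=[1333, 800].
old_w, old_h = 640, 480
max_long, max_short = 1333, 800
scale_factor = min(max_long / max(old_h, old_w), max_short / min(old_h, old_w))
new_w = int(old_w * scale_factor + 0.5)
new_h = int(old_h * scale_factor + 0.5)
print(round(scale_factor, 4), (new_w, new_h))  # 1.6667 (1067, 800)
```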
def imrescale(img: np.ndarray,
scale: Union[float, Tuple[int, int]],
return_scale: bool=False,
interpolation: str='bilinear',
backend: Optional[str]=None) -> Union[np.ndarray, Tuple[
np.ndarray, float]]:
"""Resize image while keeping the aspect ratio.
Args:
img (ndarray): The input image.
scale (float | tuple[int]): The scaling factor or maximum size.
If it is a float number, then the image will be rescaled by this
factor, else if it is a tuple of 2 integers, then the image will
be rescaled as large as possible within the scale.
return_scale (bool): Whether to return the scaling factor besides the
rescaled image.
interpolation (str): Same as :func:`resize`.
backend (str | None): Same as :func:`resize`.
Returns:
ndarray: The rescaled image.
"""
h, w = img.shape[:2]
new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
rescaled_img = imresize(
img, new_size, interpolation=interpolation, backend=backend)
if return_scale:
return rescaled_img, scale_factor
else:
return rescaled_img
def imresize(
img: np.ndarray,
size: Tuple[int, int],
return_scale: bool=False,
interpolation: str='bilinear',
out: Optional[np.ndarray]=None,
backend: Optional[str]=None,
interp=cv2.INTER_LINEAR, ) -> Union[Tuple[np.ndarray, float, float],
np.ndarray]:
"""Resize image to a given size.
Args:
img (ndarray): The input image.
size (tuple[int]): Target size (w, h).
return_scale (bool): Whether to return `w_scale` and `h_scale`.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend.
out (ndarray): The output destination.
backend (str | None): The image resize backend type. Options are `cv2`,
`pillow`, `None`. If backend is None, the global imread_backend
specified by ``mmcv.use_backend()`` will be used. Default: None.
Returns:
tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
`resized_img`.
"""
h, w = img.shape[:2]
if backend is None:
backend = imread_backend
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported for resize. '
f"Supported backends are 'cv2', 'pillow'")
if backend == 'pillow':
assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
pil_image = Image.fromarray(img)
pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
resized_img = np.array(pil_image)
else:
resized_img = cv2.resize(img, size, dst=out, interpolation=interp)
if not return_scale:
return resized_img
else:
w_scale = size[0] / w
h_scale = size[1] / h
return resized_img, w_scale, h_scale
class PETR_Resize:
"""Resize images & bbox & mask.
This transform resizes the input image to some scale. Bboxes and masks are
then resized with the same scale factor. If the input dict contains the key
"scale", then the scale in the input dict is used, otherwise the specified
scale in the init method is used. If the input dict contains the key
"scale_factor" (if MultiScaleFlipAug does not give img_scale but
scale_factor), the actual scale will be computed by image shape and
scale_factor.
`img_scale` can either be a tuple (single-scale) or a list of tuple
(multi-scale). There are 3 multiscale modes:
- ``ratio_range is not None``: randomly sample a ratio from the ratio \
range and multiply it with the image scale.
- ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \
sample a scale from the multiscale range.
- ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \
sample a scale from multiple scales.
Args:
img_scale (tuple or list[tuple]): Images scales for resizing.
multiscale_mode (str): Either "range" or "value".
ratio_range (tuple[float]): (min_ratio, max_ratio)
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generate slightly different results. Defaults
to 'cv2'.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend.
override (bool, optional): Whether to override `scale` and
`scale_factor` so as to call resize twice. Default False. If True,
after the first resizing, the existing `scale` and `scale_factor`
will be ignored so the second resizing can be allowed.
This option is a work-around for multiple times of resize in DETR.
Defaults to False.
"""
def __init__(self,
img_scale=None,
multiscale_mode='range',
ratio_range=None,
keep_ratio=True,
bbox_clip_border=True,
backend='cv2',
interpolation='bilinear',
override=False,
keypoint_clip_border=True):
if img_scale is None:
self.img_scale = None
else:
if isinstance(img_scale, list):
self.img_scale = img_scale
else:
self.img_scale = [img_scale]
assert isinstance(self.img_scale, list)
if ratio_range is not None:
# mode 1: given a scale and a range of image ratio
assert len(self.img_scale) == 1
else:
# mode 2: given multiple scales or a range of scales
assert multiscale_mode in ['value', 'range']
self.backend = backend
self.multiscale_mode = multiscale_mode
self.ratio_range = ratio_range
self.keep_ratio = keep_ratio
# TODO: refactor the override option in Resize
self.interpolation = interpolation
self.override = override
self.bbox_clip_border = bbox_clip_border
self.keypoint_clip_border = keypoint_clip_border
@staticmethod
def random_select(img_scales):
"""Randomly select an img_scale from given candidates.
Args:
img_scales (list[tuple]): Images scales for selection.
Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
where ``img_scale`` is the selected image scale and \
``scale_idx`` is the selected index in the given candidates.
"""
assert isinstance(img_scales, list)
scale_idx = np.random.randint(len(img_scales))
img_scale = img_scales[scale_idx]
return img_scale, scale_idx
@staticmethod
def random_sample(img_scales):
"""Randomly sample an img_scale when ``multiscale_mode=='range'``.
Args:
img_scales (list[tuple]): Images scale range for sampling.
There must be two tuples in img_scales, which specify the lower
and upper bound of image scales.
Returns:
(tuple, None): Returns a tuple ``(img_scale, None)``, where \
``img_scale`` is sampled scale and None is just a placeholder \
to be consistent with :func:`random_select`.
"""
assert isinstance(img_scales, list) and len(img_scales) == 2
img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales]
long_edge = np.random.randint(
min(img_scale_long), max(img_scale_long) + 1)
short_edge = np.random.randint(
min(img_scale_short), max(img_scale_short) + 1)
img_scale = (long_edge, short_edge)
return img_scale, None
@staticmethod
def random_sample_ratio(img_scale, ratio_range):
"""Randomly sample an img_scale when ``ratio_range`` is specified.
A ratio will be randomly sampled from the range specified by
``ratio_range``. Then it would be multiplied with ``img_scale`` to
generate sampled scale.
Args:
img_scale (list): Images scale base to multiply with ratio.
ratio_range (tuple[float]): The minimum and maximum ratio to scale
the ``img_scale``.
Returns:
(tuple, None): Returns a tuple ``(scale, None)``, where \
``scale`` is sampled ratio multiplied with ``img_scale`` and \
None is just a placeholder to be consistent with \
:func:`random_select`.
"""
assert isinstance(img_scale, list) and len(img_scale) == 2
min_ratio, max_ratio = ratio_range
assert min_ratio <= max_ratio
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
return scale, None
def _random_scale(self, results):
"""Randomly sample an img_scale according to ``ratio_range`` and
``multiscale_mode``.
If ``ratio_range`` is specified, a ratio will be sampled and be
multiplied with ``img_scale``.
If multiple scales are specified by ``img_scale``, a scale will be
sampled according to ``multiscale_mode``.
Otherwise, single scale will be used.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: Two new keys ``scale`` and ``scale_idx`` are added into \
``results``, which would be used by subsequent pipelines.
"""
if self.ratio_range is not None:
scale, scale_idx = self.random_sample_ratio(self.img_scale[0],
self.ratio_range)
elif len(self.img_scale) == 1:
scale, scale_idx = self.img_scale[0], 0
elif self.multiscale_mode == 'range':
scale, scale_idx = self.random_sample(self.img_scale)
elif self.multiscale_mode == 'value':
scale, scale_idx = self.random_select(self.img_scale)
else:
raise NotImplementedError
results['scale'] = scale
results['scale_idx'] = scale_idx
def _resize_img(self, results):
"""Resize images with ``results['scale']``."""
for key in ['image'] if 'image' in results else []:
if self.keep_ratio:
img, scale_factor = imrescale(
results[key],
results['scale'],
return_scale=True,
interpolation=self.interpolation,
backend=self.backend)
# the w_scale and h_scale have a minor difference
# a real fix should be done in the imrescale in the future
new_h, new_w = img.shape[:2]
h, w = results[key].shape[:2]
w_scale = new_w / w
h_scale = new_h / h
else:
img, w_scale, h_scale = imresize(
results[key],
results['scale'],
return_scale=True,
interpolation=self.interpolation,
backend=self.backend)
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
results['im_shape'] = np.array(img.shape)
# in case that there is no padding
results['pad_shape'] = img.shape
results['scale_factor'] = scale_factor
results['keep_ratio'] = self.keep_ratio
# img_pad = self.impad(img, shape=results['scale'])
results[key] = img
def _resize_bboxes(self, results):
"""Resize bounding boxes with ``results['scale_factor']``."""
for key in ['gt_bbox'] if 'gt_bbox' in results else []:
bboxes = results[key] * results['scale_factor']
if self.bbox_clip_border:
img_shape = results['im_shape']
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
results[key] = bboxes
def _resize_masks(self, results):
"""Resize masks with ``results['scale']``"""
for key in ['mask'] if 'mask' in results else []:
if results[key] is None:
continue
if self.keep_ratio:
results[key] = results[key].rescale(results['scale'])
else:
results[key] = results[key].resize(results['im_shape'][:2])
def _resize_seg(self, results):
"""Resize semantic segmentation map with ``results['scale']``."""
for key in ['seg'] if 'seg' in results else []:
if self.keep_ratio:
gt_seg = imrescale(
results[key],
results['scale'],
interpolation='nearest',
backend=self.backend)
else:
gt_seg = imresize(
results[key],
results['scale'],
interpolation='nearest',
backend=self.backend)
results[key] = gt_seg
def _resize_keypoints(self, results):
"""Resize keypoints with ``results['scale_factor']``."""
for key in ['gt_joints'] if 'gt_joints' in results else []:
keypoints = results[key].copy()
keypoints[..., 0] = keypoints[..., 0] * results['scale_factor'][0]
keypoints[..., 1] = keypoints[..., 1] * results['scale_factor'][1]
if self.keypoint_clip_border:
img_shape = results['im_shape']
keypoints[..., 0] = np.clip(keypoints[..., 0], 0, img_shape[1])
keypoints[..., 1] = np.clip(keypoints[..., 1], 0, img_shape[0])
results[key] = keypoints
def _resize_areas(self, results):
"""Resize mask areas with ``results['scale_factor']``."""
for key in ['gt_areas'] if 'gt_areas' in results else []:
areas = results[key].copy()
areas = areas * results['scale_factor'][0] * results[
'scale_factor'][1]
results[key] = areas
def __call__(self, results):
"""Call function to resize images, bounding boxes, masks, semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'im_shape', 'pad_shape', 'scale_factor', \
'keep_ratio' keys are added into result dict.
"""
if 'scale' not in results:
if 'scale_factor' in results:
img_shape = results['image'].shape[:2]
scale_factor = results['scale_factor']
assert isinstance(scale_factor, float)
results['scale'] = tuple(
[int(x * scale_factor) for x in img_shape][::-1])
else:
self._random_scale(results)
else:
if not self.override:
assert 'scale_factor' not in results, (
'scale and scale_factor cannot be both set.')
else:
results.pop('scale')
if 'scale_factor' in results:
results.pop('scale_factor')
self._random_scale(results)
self._resize_img(results)
self._resize_bboxes(results)
self._resize_masks(results)
self._resize_seg(results)
self._resize_keypoints(results)
self._resize_areas(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(img_scale={self.img_scale}, '
repr_str += f'multiscale_mode={self.multiscale_mode}, '
repr_str += f'ratio_range={self.ratio_range}, '
repr_str += f'keep_ratio={self.keep_ratio}, '
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
repr_str += f'keypoint_clip_border={self.keypoint_clip_border})'
return repr_str
......@@ -594,6 +594,108 @@ class RandomDistort(BaseOperator):
return sample
@register_op
class PhotoMetricDistortion(BaseOperator):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Args:
brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast.
saturation_range (tuple): range of saturation.
hue_delta (int): delta of hue.
"""
def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
super(PhotoMetricDistortion, self).__init__()
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def apply(self, results, context=None):
"""Call function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
img = results['image']
img = img.astype(np.float32)
# random brightness
if np.random.randint(2):
delta = np.random.uniform(-self.brightness_delta,
self.brightness_delta)
img += delta
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = np.random.randint(2)
if mode == 1:
if np.random.randint(2):
alpha = np.random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# convert color from BGR to HSV
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
# random saturation
if np.random.randint(2):
img[..., 1] *= np.random.uniform(self.saturation_lower,
self.saturation_upper)
# random hue
if np.random.randint(2):
img[..., 0] += np.random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
# random contrast
if mode == 0:
if np.random.randint(2):
alpha = np.random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# randomly swap channels
if np.random.randint(2):
img = img[..., np.random.permutation(3)]
results['image'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
repr_str += 'contrast_range='
repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
repr_str += 'saturation_range='
repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
repr_str += f'hue_delta={self.hue_delta})'
return repr_str
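The hue step is the least obvious part of `PhotoMetricDistortion`; this small sketch mirrors it on a dummy image (OpenCV keeps H in the 0-360 range for float32 images, hence the wrap-around):

```python
import cv2
import numpy as np

img = np.random.randint(0, 256, (4, 4, 3)).astype(np.float32)   # dummy BGR image
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
hsv[..., 0] += np.random.uniform(-18, 18)    # hue_delta=18, as in the TrainReader
hsv[..., 0][hsv[..., 0] > 360] -= 360        # wrap hue back into [0, 360)
hsv[..., 0][hsv[..., 0] < 0] += 360
out = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
print(out.dtype, out.shape)                  # float32 (4, 4, 3)
```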
@register_op
class AutoAugment(BaseOperator):
def __init__(self, autoaug_type="v1"):
......@@ -771,6 +873,19 @@ class Resize(BaseOperator):
bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h)
return bbox
def apply_area(self, area, scale):
im_scale_x, im_scale_y = scale
return area * im_scale_x * im_scale_y
def apply_joints(self, joints, scale, size):
im_scale_x, im_scale_y = scale
resize_w, resize_h = size
joints[..., 0] *= im_scale_x
joints[..., 1] *= im_scale_y
joints[..., 0] = np.clip(joints[..., 0], 0, resize_w)
joints[..., 1] = np.clip(joints[..., 1], 0, resize_h)
return joints
def apply_segm(self, segms, im_size, scale):
def _resize_poly(poly, im_scale_x, im_scale_y):
resized_poly = np.array(poly).astype('float32')
......@@ -833,8 +948,8 @@ class Resize(BaseOperator):
im_scale = min(target_size_min / im_size_min,
target_size_max / im_size_max)
resize_h = im_scale * float(im_shape[0])
resize_w = im_scale * float(im_shape[1])
resize_h = int(im_scale * float(im_shape[0]) + 0.5)
resize_w = int(im_scale * float(im_shape[1]) + 0.5)
im_scale_x = im_scale
im_scale_y = im_scale
......@@ -878,6 +993,11 @@ class Resize(BaseOperator):
[im_scale_x, im_scale_y],
[resize_w, resize_h])
# apply areas
if 'gt_areas' in sample:
sample['gt_areas'] = self.apply_area(sample['gt_areas'],
[im_scale_x, im_scale_y])
# apply polygon
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2],
......@@ -911,6 +1031,11 @@ class Resize(BaseOperator):
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
if 'gt_joints' in sample:
sample['gt_joints'] = self.apply_joints(sample['gt_joints'],
[im_scale_x, im_scale_y],
[resize_w, resize_h])
return sample
......@@ -1362,7 +1487,8 @@ class RandomCrop(BaseOperator):
num_attempts=50,
allow_no_crop=True,
cover_all_box=False,
is_mask_crop=False):
is_mask_crop=False,
ioumode="iou"):
super(RandomCrop, self).__init__()
self.aspect_ratio = aspect_ratio
self.thresholds = thresholds
......@@ -1371,6 +1497,7 @@ class RandomCrop(BaseOperator):
self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box
self.is_mask_crop = is_mask_crop
self.ioumode = ioumode
def crop_segms(self, segms, valid_ids, crop, height, width):
def _crop_poly(segm, crop):
......@@ -1516,6 +1643,11 @@ class RandomCrop(BaseOperator):
crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
if self.ioumode == "iof":
iou = self._gtcropiou_matrix(
gt_bbox, np.array(
[crop_box], dtype=np.float32))
elif self.ioumode == "iou":
iou = self._iou_matrix(
gt_bbox, np.array(
[crop_box], dtype=np.float32))
......@@ -1582,6 +1714,10 @@ class RandomCrop(BaseOperator):
sample['difficult'] = np.take(
sample['difficult'], valid_ids, axis=0)
if 'gt_joints' in sample:
sample['gt_joints'] = self._crop_joints(sample['gt_joints'],
crop_box)
return sample
return sample
......@@ -1596,6 +1732,16 @@ class RandomCrop(BaseOperator):
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10)
def _gtcropiou_matrix(self, a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_a[:, np.newaxis] + 1e-10)
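Worked example of the new `iof` mode: dividing the intersection by the ground-truth box area (rather than the union) means a crop that fully contains a box scores 1.0 no matter how large the crop is.

```python
import numpy as np

gt = np.array([[10., 10., 50., 50.]])        # one 40x40 gt box (x1, y1, x2, y2)
crop = np.array([[0., 0., 100., 100.]])      # candidate crop containing it
tl = np.maximum(gt[:, np.newaxis, :2], crop[:, :2])
br = np.minimum(gt[:, np.newaxis, 2:], crop[:, 2:])
inter = np.prod(br - tl, axis=2) * (tl < br).all(axis=2)
area_gt = np.prod(gt[:, 2:] - gt[:, :2], axis=1)
print(inter / (area_gt[:, np.newaxis] + 1e-10))  # [[1.]] -> gt box fully covered
```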
def _crop_box_with_center_constraint(self, box, crop):
cropped_box = box.copy()
......@@ -1620,6 +1766,16 @@ class RandomCrop(BaseOperator):
x1, y1, x2, y2 = crop
return segm[:, y1:y2, x1:x2]
def _crop_joints(self, joints, crop):
x1, y1, x2, y2 = crop
joints[joints[..., 0] > x2, :] = 0
joints[joints[..., 1] > y2, :] = 0
joints[joints[..., 0] < x1, :] = 0
joints[joints[..., 1] < y1, :] = 0
joints[..., 0] -= x1
joints[..., 1] -= y1
return joints
@register_op
class RandomScaledCrop(BaseOperator):
......@@ -1648,8 +1804,8 @@ class RandomScaledCrop(BaseOperator):
random_dim = int(dim * random_scale)
dim_max = max(h, w)
scale = random_dim / dim_max
resize_w = w * scale
resize_h = h * scale
resize_w = int(w * scale + 0.5)
resize_h = int(h * scale + 0.5)
offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
......@@ -2316,8 +2472,7 @@ class RandomResizeCrop(BaseOperator):
is_mask_crop(bool): whether crop the segmentation.
"""
def __init__(
self,
def __init__(self,
resizes,
cropsizes,
prob=0.5,
......@@ -2328,13 +2483,15 @@ class RandomResizeCrop(BaseOperator):
cover_all_box=False,
allow_no_crop=False,
thresholds=[0.3, 0.5, 0.7],
is_mask_crop=False, ):
is_mask_crop=False,
ioumode="iou"):
super(RandomResizeCrop, self).__init__()
self.resizes = resizes
self.cropsizes = cropsizes
self.prob = prob
self.mode = mode
self.ioumode = ioumode
self.resizer = Resize(0, keep_ratio=keep_ratio, interp=interp)
self.croper = RandomCrop(
......@@ -2389,6 +2546,11 @@ class RandomResizeCrop(BaseOperator):
crop_x = random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
if self.ioumode == "iof":
iou = self._gtcropiou_matrix(
gt_bbox, np.array(
[crop_box], dtype=np.float32))
elif self.ioumode == "iou":
iou = self._iou_matrix(
gt_bbox, np.array(
[crop_box], dtype=np.float32))
......@@ -2447,6 +2609,14 @@ class RandomResizeCrop(BaseOperator):
if 'is_crowd' in sample:
sample['is_crowd'] = np.take(
sample['is_crowd'], valid_ids, axis=0)
if 'gt_areas' in sample:
sample['gt_areas'] = np.take(
sample['gt_areas'], valid_ids, axis=0)
if 'gt_joints' in sample:
gt_joints = self._crop_joints(sample['gt_joints'], crop_box)
sample['gt_joints'] = gt_joints[valid_ids]
return sample
return sample
......@@ -2479,8 +2649,8 @@ class RandomResizeCrop(BaseOperator):
im_scale = max(target_size_min / im_size_min,
target_size_max / im_size_max)
resize_h = im_scale * float(im_shape[0])
resize_w = im_scale * float(im_shape[1])
resize_h = int(im_scale * float(im_shape[0]) + 0.5)
resize_w = int(im_scale * float(im_shape[1]) + 0.5)
im_scale_x = im_scale
im_scale_y = im_scale
......@@ -2540,6 +2710,11 @@ class RandomResizeCrop(BaseOperator):
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
if 'gt_joints' in sample:
sample['gt_joints'] = self.apply_joints(sample['gt_joints'],
[im_scale_x, im_scale_y],
[resize_w, resize_h])
return sample
......@@ -2612,10 +2787,10 @@ class RandomShortSideResize(BaseOperator):
if w < h:
ow = size
oh = int(size * h / w)
oh = int(round(size * h / w))
else:
oh = size
ow = int(size * w / h)
ow = int(round(size * w / h))
return (ow, oh)
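        # Worked example (not part of this diff): for a 640x480 image and size=512,
        # the height is the short side, so oh = 512 and
        # ow = int(round(512 * 640 / 480)) = 683; the previous int() truncation
        # returned 682 instead.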
......@@ -2672,6 +2847,16 @@ class RandomShortSideResize(BaseOperator):
for gt_segm in sample['gt_segm']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
if 'gt_joints' in sample:
sample['gt_joints'] = self.apply_joints(
sample['gt_joints'], [im_scale_x, im_scale_y], target_size)
# apply areas
if 'gt_areas' in sample:
sample['gt_areas'] = self.apply_area(sample['gt_areas'],
[im_scale_x, im_scale_y])
return sample
def apply_bbox(self, bbox, scale, size):
......@@ -2683,6 +2868,23 @@ class RandomShortSideResize(BaseOperator):
bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h)
return bbox.astype('float32')
def apply_joints(self, joints, scale, size):
im_scale_x, im_scale_y = scale
resize_w, resize_h = size
joints[..., 0] *= im_scale_x
joints[..., 1] *= im_scale_y
# joints[joints[..., 0] >= resize_w, :] = 0
# joints[joints[..., 1] >= resize_h, :] = 0
# joints[joints[..., 0] < 0, :] = 0
# joints[joints[..., 1] < 0, :] = 0
joints[..., 0] = np.clip(joints[..., 0], 0, resize_w)
joints[..., 1] = np.clip(joints[..., 1], 0, resize_h)
return joints
def apply_area(self, area, scale):
im_scale_x, im_scale_y = scale
return area * im_scale_x * im_scale_y
def apply_segm(self, segms, im_size, scale):
def _resize_poly(poly, im_scale_x, im_scale_y):
resized_poly = np.array(poly).astype('float32')
......@@ -2730,6 +2932,44 @@ class RandomShortSideResize(BaseOperator):
return self.resize(sample, target_size, self.max_size, interp)
@register_op
class RandomShortSideRangeResize(RandomShortSideResize):
def __init__(self, scales, interp=cv2.INTER_LINEAR, random_interp=False):
"""
        Randomly resize the image by sampling a (long, short) target size from the
        given scale pairs; the aspect ratio is kept during resizing.
        Args:
            scales (list|tuple): Candidate (long, short) image size pairs that define
                the sampling range.
            interp (int): The interpolation method.
            random_interp (bool): Whether to randomly select the interpolation method.
"""
super(RandomShortSideRangeResize, self).__init__(scales, None, interp,
random_interp)
        assert isinstance(scales,
                          Sequence), "scales must be a List or Tuple"
self.scales = scales
def random_sample(self, img_scales):
img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales]
long_edge = np.random.randint(
min(img_scale_long), max(img_scale_long) + 1)
short_edge = np.random.randint(
min(img_scale_short), max(img_scale_short) + 1)
img_scale = (long_edge, short_edge)
return img_scale
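    # Usage sketch (not part of this diff), with assumed scales = [(1333, 400),
    # (1333, 1200)]: the long candidates are [1333, 1333] and the short candidates
    # are [400, 1200], so long_edge is always 1333 while short_edge is drawn
    # uniformly from [400, 1200]; the returned img_scale is (1333, short_edge).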
def apply(self, sample, context=None):
long_edge, short_edge = self.random_sample(self.short_side_sizes)
# print("target size:{}".format((long_edge, short_edge)))
interp = random.choice(
self.interps) if self.random_interp else self.interp
return self.resize(sample, short_edge, long_edge, interp)
@register_op
class RandomSizeCrop(BaseOperator):
"""
......@@ -2805,6 +3045,9 @@ class RandomSizeCrop(BaseOperator):
sample['is_crowd'] = sample['is_crowd'][keep_index] if len(
keep_index) > 0 else np.zeros(
[0, 1], dtype=np.float32)
if 'gt_areas' in sample:
sample['gt_areas'] = np.take(
sample['gt_areas'], keep_index, axis=0)
image_shape = sample['image'].shape[:2]
sample['image'] = self.paddle_crop(sample['image'], *region)
......@@ -2826,6 +3069,12 @@ class RandomSizeCrop(BaseOperator):
if keep_index is not None and len(keep_index) > 0:
sample['gt_segm'] = sample['gt_segm'][keep_index]
if 'gt_joints' in sample:
gt_joints = self._crop_joints(sample['gt_joints'], region)
sample['gt_joints'] = gt_joints
if keep_index is not None:
sample['gt_joints'] = sample['gt_joints'][keep_index]
return sample
def apply_bbox(self, bbox, region):
......@@ -2836,6 +3085,19 @@ class RandomSizeCrop(BaseOperator):
crop_bbox = crop_bbox.clip(min=0)
return crop_bbox.reshape([-1, 4]).astype('float32')
def _crop_joints(self, joints, region):
y1, x1, h, w = region
x2 = x1 + w
y2 = y1 + h
# x1, y1, x2, y2 = crop
joints[..., 0] -= x1
joints[..., 1] -= y1
joints[joints[..., 0] > w, :] = 0
joints[joints[..., 1] > h, :] = 0
joints[joints[..., 0] < 0, :] = 0
joints[joints[..., 1] < 0, :] = 0
return joints
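    # Note (not part of this diff): `region` follows the (top, left, height, width)
    # order passed to self.paddle_crop above, and the joints are shifted into crop
    # coordinates *before* the out-of-range check, whereas RandomCrop._crop_joints
    # masks against an (x1, y1, x2, y2) box first and shifts afterwards.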
def apply_segm(self, segms, region, image_shape):
def _crop_poly(segm, crop):
xmin, ymin, xmax, ymax = crop
......
......@@ -72,3 +72,4 @@ from .yolof import *
from .pose3d_metro import *
from .centertrack import *
from .queryinst import *
from .keypoint_petr import *
......@@ -394,6 +394,7 @@ class TinyPose3DHRNet(BaseArch):
def __init__(self,
width,
num_joints,
fc_channel=768,
backbone='HRNet',
loss='KeyPointRegressionMSELoss',
post_process=TinyPose3DPostProcess):
......@@ -411,21 +412,13 @@ class TinyPose3DHRNet(BaseArch):
self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True)
self.final_conv_new = L.Conv2d(
width, num_joints * 32, 1, 1, 0, bias=True)
self.flatten = paddle.nn.Flatten(start_axis=2, stop_axis=3)
self.fc1 = paddle.nn.Linear(768, 256)
self.fc1 = paddle.nn.Linear(fc_channel, 256)
self.act1 = paddle.nn.ReLU()
self.fc2 = paddle.nn.Linear(256, 64)
self.act2 = paddle.nn.ReLU()
self.fc3 = paddle.nn.Linear(64, 3)
# for human3.6M
self.fc1_1 = paddle.nn.Linear(3136, 1024)
self.fc2_1 = paddle.nn.Linear(1024, 256)
self.fc3_1 = paddle.nn.Linear(256, 3)
@classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on https://github.com/hikvision-research/opera/blob/main/opera/models/detectors/petr.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch
from .. import layers as L
__all__ = ['PETR']
@register
class PETR(BaseArch):
__category__ = 'architecture'
__inject__ = ['backbone', 'neck', 'bbox_head']
def __init__(self,
backbone='ResNet',
neck='ChannelMapper',
bbox_head='PETRHead'):
"""
PETR, see https://openaccess.thecvf.com/content/CVPR2022/papers/Shi_End-to-End_Multi-Person_Pose_Estimation_With_Transformers_CVPR_2022_paper.pdf
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck between backbone and head
bbox_head (nn.Layer): model output and loss
"""
super(PETR, self).__init__()
self.backbone = backbone
if neck is not None:
self.with_neck = True
self.neck = neck
self.bbox_head = bbox_head
self.deploy = False
def extract_feat(self, img):
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def get_inputs(self):
img_metas = []
gt_bboxes = []
gt_labels = []
gt_keypoints = []
gt_areas = []
pad_gt_mask = self.inputs['pad_gt_mask'].astype("bool").squeeze(-1)
for idx, im_shape in enumerate(self.inputs['im_shape']):
img_meta = {
'img_shape': im_shape.astype("int32").tolist() + [1, ],
'batch_input_shape': self.inputs['image'].shape[-2:],
'image_name': self.inputs['image_file'][idx]
}
img_metas.append(img_meta)
if (not pad_gt_mask[idx].any()):
gt_keypoints.append(self.inputs['gt_joints'][idx][:1])
gt_labels.append(self.inputs['gt_class'][idx][:1])
gt_bboxes.append(self.inputs['gt_bbox'][idx][:1])
gt_areas.append(self.inputs['gt_areas'][idx][:1])
continue
gt_keypoints.append(self.inputs['gt_joints'][idx][pad_gt_mask[idx]])
gt_labels.append(self.inputs['gt_class'][idx][pad_gt_mask[idx]])
gt_bboxes.append(self.inputs['gt_bbox'][idx][pad_gt_mask[idx]])
gt_areas.append(self.inputs['gt_areas'][idx][pad_gt_mask[idx]])
return img_metas, gt_bboxes, gt_labels, gt_keypoints, gt_areas
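    # Note (not part of this diff): when an image has no valid gt (pad_gt_mask is
    # all zeros), the first padded entry of each field is kept so every per-image
    # tensor stays non-empty with a consistent rank; such an all-zero entry is
    # later treated as "no gt" by the assigner's empty-keypoint check.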
def get_loss(self):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes of one
                image, in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            gt_keypoints (list[Tensor]): Each item is the ground-truth keypoints of
                one image, in [p^{1}_x, p^{1}_y, p^{1}_v, ..., p^{K}_x,
                p^{K}_y, p^{K}_v] format.
gt_areas (list[Tensor]): mask areas corresponding to each box.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
img_metas, gt_bboxes, gt_labels, gt_keypoints, gt_areas = self.get_inputs(
)
        gt_bboxes_ignore = self.inputs.get('gt_bboxes_ignore', None)
x = self.extract_feat(self.inputs)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_keypoints, gt_areas,
gt_bboxes_ignore)
loss = 0
for k, v in losses.items():
loss += v
losses['loss'] = loss
return losses
def get_pred_numpy(self):
"""Used for computing network flops.
"""
img = self.inputs['image']
batch_size, _, height, width = img.shape
dummy_img_metas = [
dict(
batch_input_shape=(height, width),
img_shape=(height, width, 3),
scale_factor=(1., 1., 1., 1.)) for _ in range(batch_size)
]
x = self.extract_feat(img)
outs = self.bbox_head(x, img_metas=dummy_img_metas)
bbox_list = self.bbox_head.get_bboxes(
*outs, dummy_img_metas, rescale=True)
return bbox_list
def get_pred(self):
"""
"""
img = self.inputs['image']
batch_size, _, height, width = img.shape
img_metas = [
dict(
batch_input_shape=(height, width),
img_shape=(height, width, 3),
scale_factor=self.inputs['scale_factor'][i])
for i in range(batch_size)
]
kptpred = self.simple_test(
self.inputs, img_metas=img_metas, rescale=True)
keypoints = kptpred[0][1][0]
bboxs = kptpred[0][0][0]
keypoints[..., 2] = bboxs[:, None, 4]
res_lst = [[keypoints, bboxs[:, 4]]]
outputs = {'keypoint': res_lst}
return outputs
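    # Note (not part of this diff): the detection score is copied into the
    # visibility channel of each keypoint (keypoints[..., 2]) and the results are
    # packed as [[keypoints, scores]] under the 'keypoint' key, matching the
    # keypoint output format used elsewhere in PaddleDetection's keypoint models.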
def simple_test(self, inputs, img_metas, rescale=False):
"""Test function without test time augmentation.
Args:
inputs (list[paddle.Tensor]): List of multiple images.
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[list[np.ndarray]]: BBox and keypoint results of each image
and classes. The outer list corresponds to each image.
The inner list corresponds to each class.
"""
batch_size = len(img_metas)
assert batch_size == 1, 'Currently only batch_size 1 for inference ' \
f'mode is supported. Found batch_size {batch_size}.'
feat = self.extract_feat(inputs)
results_list = self.bbox_head.simple_test(
feat, img_metas, rescale=rescale)
bbox_kpt_results = [
self.bbox_kpt2result(det_bboxes, det_labels, det_kpts,
self.bbox_head.num_classes)
for det_bboxes, det_labels, det_kpts in results_list
]
return bbox_kpt_results
def bbox_kpt2result(self, bboxes, labels, kpts, num_classes):
"""Convert detection results to a list of numpy arrays.
Args:
bboxes (paddle.Tensor | np.ndarray): shape (n, 5).
labels (paddle.Tensor | np.ndarray): shape (n, ).
kpts (paddle.Tensor | np.ndarray): shape (n, K, 3).
num_classes (int): class number, including background class.
Returns:
list(ndarray): bbox and keypoint results of each class.
"""
if bboxes.shape[0] == 0:
            return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)], \
                [np.zeros((0, kpts.shape[1], 3), dtype=np.float32)
                 for i in range(num_classes)]
else:
if isinstance(bboxes, paddle.Tensor):
bboxes = bboxes.numpy()
labels = labels.numpy()
kpts = kpts.numpy()
return [bboxes[labels == i, :] for i in range(num_classes)], \
[kpts[labels == i, :, :] for i in range(num_classes)]
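# Illustration (not part of this diff): bbox_kpt2result simply groups detections by
# class index. With a single foreground class, each output is a one-element list.
import numpy as np

bboxes = np.array([[0., 0., 10., 10., 0.9]], dtype=np.float32)  # (n, 5) x1, y1, x2, y2, score
labels = np.array([0])                                          # class index per detection
kpts = np.zeros((1, 17, 3), dtype=np.float32)                   # (n, K, 3)
num_classes = 1
bbox_result = [bboxes[labels == i, :] for i in range(num_classes)]
kpt_result = [kpts[labels == i, :, :] for i in range(num_classes)]
# bbox_result[0].shape == (1, 5), kpt_result[0].shape == (1, 17, 3)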
......@@ -31,3 +31,5 @@ from .fcosr_assigner import *
from .rotated_task_aligned_assigner import *
from .task_aligned_assigner_cr import *
from .uniform_assigner import *
from .hungarian_assigner import *
from .pose_utils import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from scipy.optimize import linear_sum_assignment
except ImportError:
linear_sum_assignment = None
import paddle
from ppdet.core.workspace import register
__all__ = ['PoseHungarianAssigner', 'PseudoSampler']
class AssignResult:
"""Stores assignments between predicted and truth boxes.
Attributes:
num_gts (int): the number of truth boxes considered when computing this
assignment
gt_inds (LongTensor): for each predicted box indicates the 1-based
index of the assigned truth box. 0 means unassigned and -1 means
ignore.
max_overlaps (FloatTensor): the iou between the predicted box and its
assigned truth box.
labels (None | LongTensor): If specified, for each predicted box
indicates the category label of the assigned truth box.
"""
def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
self.num_gts = num_gts
self.gt_inds = gt_inds
self.max_overlaps = max_overlaps
self.labels = labels
# Interface for possible user-defined properties
self._extra_properties = {}
@property
def num_preds(self):
"""int: the number of predictions in this assignment"""
return len(self.gt_inds)
def set_extra_property(self, key, value):
"""Set user-defined new property."""
assert key not in self.info
self._extra_properties[key] = value
def get_extra_property(self, key):
"""Get user-defined property."""
return self._extra_properties.get(key, None)
@property
def info(self):
"""dict: a dictionary of info about the object"""
basic_info = {
'num_gts': self.num_gts,
'num_preds': self.num_preds,
'gt_inds': self.gt_inds,
'max_overlaps': self.max_overlaps,
'labels': self.labels,
}
basic_info.update(self._extra_properties)
return basic_info
@register
class PoseHungarianAssigner:
"""Computes one-to-one matching between predictions and ground truth.
This class computes an assignment between the targets and the predictions
based on the costs. The costs are weighted sum of three components:
classification cost, regression L1 cost and regression oks cost. The
targets don't include the no_object, so generally there are more
predictions than targets. After the one-to-one matching, the un-matched
are treated as backgrounds. Thus each query prediction will be assigned
with `0` or a positive integer indicating the ground truth index:
- 0: negative sample, no assigned gt.
- positive integer: positive sample, index (1-based) of assigned gt.
Args:
cls_weight (int | float, optional): The scale factor for classification
cost. Default 1.0.
kpt_weight (int | float, optional): The scale factor for regression
L1 cost. Default 1.0.
oks_weight (int | float, optional): The scale factor for regression
oks cost. Default 1.0.
"""
__inject__ = ['cls_cost', 'kpt_cost', 'oks_cost']
def __init__(self,
cls_cost='ClassificationCost',
kpt_cost='KptL1Cost',
oks_cost='OksCost'):
self.cls_cost = cls_cost
self.kpt_cost = kpt_cost
self.oks_cost = oks_cost
def assign(self,
cls_pred,
kpt_pred,
gt_labels,
gt_keypoints,
gt_areas,
img_meta,
eps=1e-7):
"""Computes one-to-one matching based on the weighted costs.
        This method assigns each query prediction to a ground truth or
background. The `assigned_gt_inds` with -1 means don't care,
0 means negative sample, and positive number is the index (1-based)
of assigned gt.
The assignment is done in the following steps, the order matters.
1. assign every prediction to -1
2. compute the weighted costs
3. do Hungarian matching on CPU based on the costs
4. assign all to 0 (background) first, then for each matched pair
between predictions and gts, treat this prediction as foreground
and assign the corresponding gt index (plus 1) to it.
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
kpt_pred (Tensor): Predicted keypoints with normalized coordinates
(x_{i}, y_{i}), which are all in range [0, 1]. Shape
[num_query, K*2].
gt_labels (Tensor): Label of `gt_keypoints`, shape (num_gt,).
gt_keypoints (Tensor): Ground truth keypoints with unnormalized
coordinates [p^{1}_x, p^{1}_y, p^{1}_v, ..., \
p^{K}_x, p^{K}_y, p^{K}_v]. Shape [num_gt, K*3].
gt_areas (Tensor): Ground truth mask areas, shape (num_gt,).
img_meta (dict): Meta information for current image.
eps (int | float, optional): A value added to the denominator for
numerical stability. Default 1e-7.
Returns:
:obj:`AssignResult`: The assigned result.
"""
num_gts, num_kpts = gt_keypoints.shape[0], kpt_pred.shape[0]
if not gt_keypoints.astype('bool').any():
num_gts = 0
# 1. assign -1 by default
assigned_gt_inds = paddle.full((num_kpts, ), -1, dtype="int64")
assigned_labels = paddle.full((num_kpts, ), -1, dtype="int64")
if num_gts == 0 or num_kpts == 0:
# No ground truth or keypoints, return empty assignment
if num_gts == 0:
# No ground truth, assign all to background
assigned_gt_inds[:] = 0
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
img_h, img_w, _ = img_meta['img_shape']
factor = paddle.to_tensor(
[img_w, img_h, img_w, img_h], dtype=gt_keypoints.dtype).reshape(
(1, -1))
# 2. compute the weighted costs
# classification cost
cls_cost = self.cls_cost(cls_pred, gt_labels)
# keypoint regression L1 cost
gt_keypoints_reshape = gt_keypoints.reshape((gt_keypoints.shape[0], -1,
3))
valid_kpt_flag = gt_keypoints_reshape[..., -1]
kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1,
2))
normalize_gt_keypoints = gt_keypoints_reshape[
..., :2] / factor[:, :2].unsqueeze(0)
kpt_cost = self.kpt_cost(kpt_pred_tmp, normalize_gt_keypoints,
valid_kpt_flag)
# keypoint OKS cost
kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1,
2))
kpt_pred_tmp = kpt_pred_tmp * factor[:, :2].unsqueeze(0)
oks_cost = self.oks_cost(kpt_pred_tmp, gt_keypoints_reshape[..., :2],
valid_kpt_flag, gt_areas)
# weighted sum of above three costs
cost = cls_cost + kpt_cost + oks_cost
# 3. do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
if linear_sum_assignment is None:
raise ImportError('Please run "pip install scipy" '
'to install scipy first.')
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
matched_row_inds = paddle.to_tensor(matched_row_inds)
matched_col_inds = paddle.to_tensor(matched_col_inds)
# 4. assign backgrounds and foregrounds
# assign all indices to backgrounds first
assigned_gt_inds[:] = 0
# assign foregrounds based on matching results
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][
..., 0].astype("int64")
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
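# Sketch (not part of this diff): step 3 above on a toy cost matrix. scipy's
# linear_sum_assignment returns the row/column indices of the minimum-cost
# one-to-one matching; matched queries then get gt index + 1, the rest stay 0.
import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.9, 0.1],
                 [0.2, 0.8],
                 [0.5, 0.6]])  # [num_query=3, num_gt=2]
rows, cols = linear_sum_assignment(cost)
# rows == [0, 1], cols == [1, 0]: query 0 -> gt 2 (1-based), query 1 -> gt 1,
# query 2 remains background (assigned_gt_inds == 0).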
class SamplingResult:
"""Bbox sampling result.
"""
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
gt_flags):
self.pos_inds = pos_inds
self.neg_inds = neg_inds
if pos_inds.size > 0:
self.pos_bboxes = bboxes[pos_inds]
self.neg_bboxes = bboxes[neg_inds]
self.pos_is_gt = gt_flags[pos_inds]
self.num_gts = gt_bboxes.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
if gt_bboxes.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_bboxes = paddle.zeros(
gt_bboxes.shape, dtype=gt_bboxes.dtype).reshape((-1, 4))
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.reshape((-1, 4))
self.pos_gt_bboxes = paddle.index_select(
gt_bboxes,
self.pos_assigned_gt_inds.astype('int64'),
axis=0)
if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds]
else:
self.pos_gt_labels = None
@property
def bboxes(self):
"""paddle.Tensor: concatenated positive and negative boxes"""
return paddle.concat([self.pos_bboxes, self.neg_bboxes])
def __nice__(self):
data = self.info.copy()
data['pos_bboxes'] = data.pop('pos_bboxes').shape
data['neg_bboxes'] = data.pop('neg_bboxes').shape
parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
body = ' ' + ',\n '.join(parts)
return '{\n' + body + '\n}'
@property
def info(self):
"""Returns a dictionary of info about the object."""
return {
'pos_inds': self.pos_inds,
'neg_inds': self.neg_inds,
'pos_bboxes': self.pos_bboxes,
'neg_bboxes': self.neg_bboxes,
'pos_is_gt': self.pos_is_gt,
'num_gts': self.num_gts,
'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
}
@register
class PseudoSampler:
"""A pseudo sampler that does not do sampling actually."""
def __init__(self, **kwargs):
pass
def _sample_pos(self, **kwargs):
"""Sample positive samples."""
raise NotImplementedError
def _sample_neg(self, **kwargs):
"""Sample negative samples."""
raise NotImplementedError
def sample(self, assign_result, bboxes, gt_bboxes, *args, **kwargs):
"""Directly returns the positive and negative indices of samples.
Args:
assign_result (:obj:`AssignResult`): Assigned results
bboxes (paddle.Tensor): Bounding boxes
gt_bboxes (paddle.Tensor): Ground truth boxes
Returns:
:obj:`SamplingResult`: sampler results
"""
pos_inds = paddle.nonzero(
assign_result.gt_inds > 0, as_tuple=False).squeeze(-1)
neg_inds = paddle.nonzero(
assign_result.gt_inds == 0, as_tuple=False).squeeze(-1)
gt_flags = paddle.zeros([bboxes.shape[0]], dtype='int32')
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
return sampling_result
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register
__all__ = ['KptL1Cost', 'OksCost', 'ClassificationCost']
def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
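# Note (not part of this diff): masked_fill replaces the entries where `mask` is
# True with `value` and keeps the rest, e.g.
#   masked_fill(paddle.to_tensor([1., 2.]), paddle.to_tensor([True, False]), 0.)
# evaluates to [0., 2.].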
@register
class KptL1Cost(object):
"""KptL1Cost.
    this function is based on: https://github.com/hikvision-research/opera/blob/main/opera/core/bbox/match_costs/match_cost.py
Args:
weight (int | float, optional): loss_weight.
"""
def __init__(self, weight=1.0):
self.weight = weight
def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag):
"""
Args:
kpt_pred (Tensor): Predicted keypoints with normalized coordinates
(x_{i}, y_{i}), which are all in range [0, 1]. Shape
[num_query, K, 2].
gt_keypoints (Tensor): Ground truth keypoints with normalized
coordinates (x_{i}, y_{i}). Shape [num_gt, K, 2].
valid_kpt_flag (Tensor): valid flag of ground truth keypoints.
Shape [num_gt, K].
Returns:
paddle.Tensor: kpt_cost value with weight.
"""
kpt_cost = []
for i in range(len(gt_keypoints)):
if gt_keypoints[i].size == 0:
kpt_cost.append(kpt_pred.sum() * 0)
kpt_pred_tmp = kpt_pred.clone()
valid_flag = valid_kpt_flag[i] > 0
valid_flag_expand = valid_flag.unsqueeze(0).unsqueeze(-1).expand_as(
kpt_pred_tmp)
if not valid_flag_expand.all():
kpt_pred_tmp = masked_fill(kpt_pred_tmp, ~valid_flag_expand, 0)
cost = F.pairwise_distance(
kpt_pred_tmp.reshape((kpt_pred_tmp.shape[0], -1)),
gt_keypoints[i].reshape((-1, )).unsqueeze(0),
p=1,
keepdim=True)
avg_factor = paddle.clip(
valid_flag.astype('float32').sum() * 2, 1.0)
cost = cost / avg_factor
kpt_cost.append(cost)
kpt_cost = paddle.concat(kpt_cost, axis=1)
return kpt_cost * self.weight
@register
class OksCost(object):
"""OksCost.
    this function is based on: https://github.com/hikvision-research/opera/blob/main/opera/core/bbox/match_costs/match_cost.py
Args:
num_keypoints (int): number of keypoints
weight (int | float, optional): loss_weight.
"""
def __init__(self, num_keypoints=17, weight=1.0):
self.weight = weight
if num_keypoints == 17:
self.sigmas = np.array(
[
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
1.07, .87, .87, .89, .89
],
dtype=np.float32) / 10.0
elif num_keypoints == 14:
self.sigmas = np.array(
[
.79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89,
.89, .79, .79
],
dtype=np.float32) / 10.0
else:
raise ValueError(f'Unsupported keypoints number {num_keypoints}')
def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas):
"""
Args:
kpt_pred (Tensor): Predicted keypoints with unnormalized
coordinates (x_{i}, y_{i}). Shape [num_query, K, 2].
gt_keypoints (Tensor): Ground truth keypoints with unnormalized
coordinates (x_{i}, y_{i}). Shape [num_gt, K, 2].
valid_kpt_flag (Tensor): valid flag of ground truth keypoints.
Shape [num_gt, K].
gt_areas (Tensor): Ground truth mask areas. Shape [num_gt,].
Returns:
paddle.Tensor: oks_cost value with weight.
"""
sigmas = paddle.to_tensor(self.sigmas)
variances = (sigmas * 2)**2
oks_cost = []
assert len(gt_keypoints) == len(gt_areas)
for i in range(len(gt_keypoints)):
if gt_keypoints[i].size == 0:
oks_cost.append(kpt_pred.sum() * 0)
squared_distance = \
(kpt_pred[:, :, 0] - gt_keypoints[i, :, 0].unsqueeze(0)) ** 2 + \
(kpt_pred[:, :, 1] - gt_keypoints[i, :, 1].unsqueeze(0)) ** 2
vis_flag = (valid_kpt_flag[i] > 0).astype('int')
vis_ind = vis_flag.nonzero(as_tuple=False)[:, 0]
num_vis_kpt = vis_ind.shape[0]
# assert num_vis_kpt > 0
if num_vis_kpt == 0:
oks_cost.append(paddle.zeros((squared_distance.shape[0], 1)))
continue
area = gt_areas[i]
squared_distance0 = squared_distance / (area * variances * 2)
squared_distance0 = paddle.index_select(
squared_distance0, vis_ind, axis=1)
squared_distance1 = paddle.exp(-squared_distance0).sum(axis=1,
keepdim=True)
oks = squared_distance1 / num_vis_kpt
# The 1 is a constant that doesn't change the matching, so omitted.
oks_cost.append(-oks)
oks_cost = paddle.concat(oks_cost, axis=1)
return oks_cost * self.weight
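    # Note (not part of this diff): this follows the COCO OKS definition,
    #   oks = mean_i exp(-d_i^2 / (2 * s^2 * k_i^2)),
    # with s^2 taken as the gt area, k_i = 2 * sigma_i (hence `variances`), and the
    # mean restricted to visible keypoints; the negative OKS is used as the cost so
    # that a more similar pose yields a cheaper match.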
@register
class ClassificationCost:
"""ClsSoftmaxCost.
Args:
weight (int | float, optional): loss_weight
"""
def __init__(self, weight=1.):
self.weight = weight
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
(num_query, num_class).
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
paddle.Tensor: cls_cost value with weight
"""
        # Following the official DETR repo, instead of the NLL used in the loss,
        # the matching cost is approximated as 1 - cls_score[gt_label]; the
        # constant 1 does not change the matching, so it is omitted.
cls_score = cls_pred.softmax(-1)
cls_cost = -cls_score[:, gt_labels]
return cls_cost * self.weight
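# Numeric illustration (not part of this diff), with made-up logits:
import numpy as np

logits = np.array([[2.0, 0.5], [0.1, 1.5]])       # [num_query=2, num_class=2]
scores = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
gt_labels = np.array([1])                         # a single gt of class 1
cls_cost = -scores[:, gt_labels]                  # [num_query, num_gt] == [2, 1]
# the query that scores class 1 higher gets the more negative (cheaper) cost.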
@register
class FocalLossCost:
"""FocalLossCost.
Args:
weight (int | float, optional): loss_weight
alpha (int | float, optional): focal_loss alpha
gamma (int | float, optional): focal_loss gamma
eps (float, optional): default 1e-12
binary_input (bool, optional): Whether the input is binary,
default False.
"""
def __init__(self,
weight=1.,
alpha=0.25,
gamma=2,
eps=1e-12,
binary_input=False):
self.weight = weight
self.alpha = alpha
self.gamma = gamma
self.eps = eps
self.binary_input = binary_input
def _focal_loss_cost(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
(num_query, num_class).
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
paddle.Tensor: cls_cost value with weight
"""
if gt_labels.size == 0:
return cls_pred.sum() * 0
cls_pred = F.sigmoid(cls_pred)
neg_cost = -(1 - cls_pred + self.eps).log() * (
1 - self.alpha) * cls_pred.pow(self.gamma)
pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
1 - cls_pred).pow(self.gamma)
cls_cost = paddle.index_select(
pos_cost, gt_labels, axis=1) - paddle.index_select(
neg_cost, gt_labels, axis=1)
return cls_cost * self.weight
def _mask_focal_loss_cost(self, cls_pred, gt_labels):
"""
Args:
            cls_pred (Tensor): Predicted classification logits
                in shape (num_query, d1, ..., dn), dtype=paddle.float32.
            gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
                dtype=paddle.int64. Labels should be binary.
Returns:
Tensor: Focal cost matrix with weight in shape\
(num_query, num_gt).
"""
cls_pred = cls_pred.flatten(1)
gt_labels = gt_labels.flatten(1).float()
n = cls_pred.shape[1]
cls_pred = F.sigmoid(cls_pred)
neg_cost = -(1 - cls_pred + self.eps).log() * (
1 - self.alpha) * cls_pred.pow(self.gamma)
pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
1 - cls_pred).pow(self.gamma)
cls_cost = paddle.einsum('nc,mc->nm', pos_cost, gt_labels) + \
paddle.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
return cls_cost / n * self.weight
def __call__(self, cls_pred, gt_labels):
"""
Args:
            cls_pred (Tensor): Predicted classification logits.
gt_labels (Tensor)): Labels.
Returns:
Tensor: Focal cost matrix with weight in shape\
(num_query, num_gt).
"""
if self.binary_input:
return self._mask_focal_loss_cost(cls_pred, gt_labels)
else:
return self._focal_loss_cost(cls_pred, gt_labels)
......@@ -285,36 +285,6 @@ class BottleNeck(nn.Layer):
# ResNeXt
width = int(ch_out * (base_width / 64.)) * groups
self.shortcut = shortcut
if not shortcut:
if variant == 'd' and stride == 2:
self.short = nn.Sequential()
self.short.add_sublayer(
'pool',
nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True))
self.short.add_sublayer(
'conv',
ConvNormLayer(
ch_in=ch_in,
ch_out=ch_out * self.expansion,
filter_size=1,
stride=1,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr))
else:
self.short = ConvNormLayer(
ch_in=ch_in,
ch_out=ch_out * self.expansion,
filter_size=1,
stride=stride,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr)
self.branch2a = ConvNormLayer(
ch_in=ch_in,
ch_out=width,
......@@ -351,6 +321,36 @@ class BottleNeck(nn.Layer):
freeze_norm=freeze_norm,
lr=lr)
self.shortcut = shortcut
if not shortcut:
if variant == 'd' and stride == 2:
self.short = nn.Sequential()
self.short.add_sublayer(
'pool',
nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True))
self.short.add_sublayer(
'conv',
ConvNormLayer(
ch_in=ch_in,
ch_out=ch_out * self.expansion,
filter_size=1,
stride=1,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr))
else:
self.short = ConvNormLayer(
ch_in=ch_in,
ch_out=ch_out * self.expansion,
filter_size=1,
stride=stride,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr)
self.std_senet = std_senet
if self.std_senet:
self.se = SELayer(ch_out * self.expansion)
......
......@@ -284,9 +284,9 @@ class RelativePositionBias(nn.Layer):
def forward(self):
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
self.window_size[0] * self.window_size[1] + 1, -1]) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.transpose((2, 0, 1)) # nH, Wh*Ww, Wh*Ww
......
......@@ -67,3 +67,4 @@ from .yolof_head import *
from .ppyoloe_contrast_head import *
from .centertrack_head import *
from .sparse_roi_head import *
from .petr_head import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on https://github.com/hikvision-research/opera/blob/main/opera/models/dense_heads/petr_head.py
"""
import copy
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
import paddle.distributed as dist
from ..transformers.petr_transformer import inverse_sigmoid, masked_fill
from ..initializer import constant_, normal_
__all__ = ["PETRHead"]
from functools import partial
def bias_init_with_prob(prior_prob: float) -> float:
"""initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
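# Quick check (not part of this diff): for prior_prob = 0.01 this gives
# bias = -log(0.99 / 0.01) ≈ -4.595, so sigmoid(bias) ≈ 0.01 and the classification
# branches start out predicting foreground with low confidence, as is common for
# focal-loss heads.
import numpy as np
assert abs(bias_init_with_prob(0.01) + np.log(99.0)) < 1e-6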
def multi_apply(func, *args, **kwargs):
"""Apply function to a list of arguments.
Note:
        This function applies ``func`` to multiple inputs and
        maps the multiple outputs of ``func`` into different
        lists. Each list contains the same type of outputs corresponding
        to different inputs.
Args:
func (Function): A function that will be applied to a list of
arguments
Returns:
        tuple(list): A tuple containing multiple lists, where each list contains \
            one kind of result returned by the function
"""
pfunc = partial(func, **kwargs) if kwargs else func
map_results = map(pfunc, *args)
res = tuple(map(list, zip(*map_results)))
return res
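# Usage sketch (not part of this diff): multi_apply maps a function over parallel
# argument lists and transposes the per-call tuples into per-output lists.
def _square_and_cube(x):
    return x * x, x * x * x

squares, cubes = multi_apply(_square_and_cube, [1, 2, 3])
# squares == [1, 4, 9], cubes == [1, 8, 27]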
def reduce_mean(tensor):
""""Obtain the mean of tensor on different GPUs."""
if not (dist.get_world_size() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
dist.all_reduce(
tensor.divide(
paddle.to_tensor(
dist.get_world_size(), dtype='float32')),
op=dist.ReduceOp.SUM)
return tensor
def gaussian_radius(det_size, min_overlap=0.7):
"""calculate gaussian radius according to object size.
"""
height, width = det_size
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = paddle.sqrt(b1**2 - 4 * a1 * c1)
r1 = (b1 + sq1) / 2
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = paddle.sqrt(b2**2 - 4 * a2 * c2)
r2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = paddle.sqrt(b3**2 - 4 * a3 * c3)
r3 = (b3 + sq3) / 2
return min(r1, r2, r3)
def gaussian2D(shape, sigma=1):
m, n = [(ss - 1.) / 2. for ss in shape]
y = paddle.arange(-m, m + 1, dtype="float32")[:, None]
x = paddle.arange(-n, n + 1, dtype="float32")[None, :]
# y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = paddle.exp(-(x * x + y * y) / (2 * sigma * sigma))
h[h < np.finfo(np.float32).eps * h.max()] = 0
return h
def draw_umich_gaussian(heatmap, center, radius, k=1):
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
gaussian = paddle.to_tensor(gaussian, dtype=heatmap.dtype)
x, y = int(center[0]), int(center[1])
radius = int(radius)
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
radius + right]
# assert masked_gaussian.equal(1).float().sum() == 1
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
heatmap[y - top:y + bottom, x - left:x + right] = paddle.maximum(
masked_heatmap, masked_gaussian * k)
return heatmap
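# Minimal sketch (not part of this diff): splat one keypoint onto an empty heatmap;
# the size, center and radius below are made-up values.
heatmap = paddle.zeros([16, 16], dtype="float32")
heatmap = draw_umich_gaussian(heatmap, center=(8, 8), radius=2)
# heatmap[8, 8] is 1.0 and neighbouring cells fall off with the Gaussian profile;
# overlapping keypoints keep the element-wise maximum rather than summing.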
@register
class PETRHead(nn.Layer):
"""Head of `End-to-End Multi-Person Pose Estimation with Transformers`.
Args:
num_classes (int): Number of categories excluding the background.
in_channels (int): Number of channels in the input feature map.
num_query (int): Number of query in Transformer.
num_kpt_fcs (int, optional): Number of fully-connected layers used in
`FFN`, which is then used for the keypoint regression head.
Default 2.
transformer (obj:`mmcv.ConfigDict`|dict): ConfigDict is used for
building the Encoder and Decoder. Default: None.
sync_cls_avg_factor (bool): Whether to sync the avg_factor of
all ranks. Default to False.
positional_encoding (obj:`mmcv.ConfigDict`|dict):
Config for position encoding.
loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the
classification loss. Default `CrossEntropyLoss`.
loss_kpt (obj:`mmcv.ConfigDict`|dict): Config of the
regression loss. Default `L1Loss`.
loss_oks (obj:`mmcv.ConfigDict`|dict): Config of the
regression oks loss. Default `OKSLoss`.
loss_hm (obj:`mmcv.ConfigDict`|dict): Config of the
regression heatmap loss. Default `NegLoss`.
as_two_stage (bool) : Whether to generate the proposal from
the outputs of encoder.
with_kpt_refine (bool): Whether to refine the reference points
in the decoder. Defaults to True.
test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of
transformer head.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
__inject__ = [
"transformer", "positional_encoding", "assigner", "sampler", "loss_cls",
"loss_kpt", "loss_oks", "loss_hm", "loss_kpt_rpn", "loss_kpt_refine",
"loss_oks_refine"
]
def __init__(self,
num_classes,
in_channels,
num_query=100,
num_kpt_fcs=2,
num_keypoints=17,
transformer=None,
sync_cls_avg_factor=True,
positional_encoding='SinePositionalEncoding',
loss_cls='FocalLoss',
loss_kpt='L1Loss',
loss_oks='OKSLoss',
loss_hm='CenterFocalLoss',
with_kpt_refine=True,
assigner='PoseHungarianAssigner',
sampler='PseudoSampler',
loss_kpt_rpn='L1Loss',
loss_kpt_refine='L1Loss',
loss_oks_refine='opera.OKSLoss',
test_cfg=dict(max_per_img=100),
init_cfg=None,
**kwargs):
# NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
# since it brings inconvenience when the initialization of
# `AnchorFreeHead` is called.
super().__init__()
self.bg_cls_weight = 0
self.sync_cls_avg_factor = sync_cls_avg_factor
self.assigner = assigner
self.sampler = sampler
self.num_query = num_query
self.num_classes = num_classes
self.in_channels = in_channels
self.num_kpt_fcs = num_kpt_fcs
self.test_cfg = test_cfg
self.fp16_enabled = False
self.as_two_stage = transformer.as_two_stage
self.with_kpt_refine = with_kpt_refine
self.num_keypoints = num_keypoints
self.loss_cls = loss_cls
self.loss_kpt = loss_kpt
self.loss_kpt_rpn = loss_kpt_rpn
self.loss_kpt_refine = loss_kpt_refine
self.loss_oks = loss_oks
self.loss_oks_refine = loss_oks_refine
self.loss_hm = loss_hm
if self.loss_cls.use_sigmoid:
self.cls_out_channels = num_classes
else:
self.cls_out_channels = num_classes + 1
self.positional_encoding = positional_encoding
self.transformer = transformer
self.embed_dims = self.transformer.embed_dims
# assert 'num_feats' in positional_encoding
num_feats = positional_encoding.num_pos_feats
assert num_feats * 2 == self.embed_dims, 'embed_dims should' \
f' be exactly 2 times of num_feats. Found {self.embed_dims}' \
f' and {num_feats}.'
self._init_layers()
self.init_weights()
def _init_layers(self):
"""Initialize classification branch and keypoint branch of head."""
fc_cls = nn.Linear(self.embed_dims, self.cls_out_channels)
kpt_branch = []
kpt_branch.append(nn.Linear(self.embed_dims, 512))
kpt_branch.append(nn.ReLU())
for _ in range(self.num_kpt_fcs):
kpt_branch.append(nn.Linear(512, 512))
kpt_branch.append(nn.ReLU())
kpt_branch.append(nn.Linear(512, 2 * self.num_keypoints))
kpt_branch = nn.Sequential(*kpt_branch)
def _get_clones(module, N):
return nn.LayerList([copy.deepcopy(module) for i in range(N)])
# last kpt_branch is used to generate proposal from
# encode feature map when as_two_stage is True.
num_pred = (self.transformer.decoder.num_layers + 1) if \
self.as_two_stage else self.transformer.decoder.num_layers
if self.with_kpt_refine:
self.cls_branches = _get_clones(fc_cls, num_pred)
self.kpt_branches = _get_clones(kpt_branch, num_pred)
else:
self.cls_branches = nn.LayerList([fc_cls for _ in range(num_pred)])
self.kpt_branches = nn.LayerList(
[kpt_branch for _ in range(num_pred)])
self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)
refine_kpt_branch = []
for _ in range(self.num_kpt_fcs):
refine_kpt_branch.append(
nn.Linear(self.embed_dims, self.embed_dims))
refine_kpt_branch.append(nn.ReLU())
refine_kpt_branch.append(nn.Linear(self.embed_dims, 2))
refine_kpt_branch = nn.Sequential(*refine_kpt_branch)
if self.with_kpt_refine:
num_pred = self.transformer.refine_decoder.num_layers
self.refine_kpt_branches = _get_clones(refine_kpt_branch, num_pred)
self.fc_hm = nn.Linear(self.embed_dims, self.num_keypoints)
def init_weights(self):
"""Initialize weights of the PETR head."""
self.transformer.init_weights()
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
for m in self.cls_branches:
constant_(m.bias, bias_init)
for m in self.kpt_branches:
constant_(m[-1].bias, 0)
# initialization of keypoint refinement branch
if self.with_kpt_refine:
for m in self.refine_kpt_branches:
constant_(m[-1].bias, 0)
# initialize bias for heatmap prediction
bias_init = bias_init_with_prob(0.1)
normal_(self.fc_hm.weight, std=0.01)
constant_(self.fc_hm.bias, bias_init)
def forward(self, mlvl_feats, img_metas):
"""Forward function.
Args:
mlvl_feats (tuple[Tensor]): Features from the upstream
network, each is a 4D-tensor with shape
(N, C, H, W).
img_metas (list[dict]): List of image information.
Returns:
outputs_classes (Tensor): Outputs from the classification head,
shape [nb_dec, bs, num_query, cls_out_channels]. Note
cls_out_channels should include background.
            outputs_kpts (Tensor): Sigmoid outputs from the regression
                head with normalized coordinate format (x_{i}, y_{i}).
                Shape [nb_dec, bs, num_query, K*2].
enc_outputs_class (Tensor): The score of each point on encode
feature map, has shape (N, h*w, num_class). Only when
                as_two_stage is True it would be returned, otherwise
`None` would be returned.
enc_outputs_kpt (Tensor): The proposal generate from the
encode feature map, has shape (N, h*w, K*2). Only when
                as_two_stage is True it would be returned, otherwise
`None` would be returned.
"""
batch_size = mlvl_feats[0].shape[0]
input_img_h, input_img_w = img_metas[0]['batch_input_shape']
img_masks = paddle.zeros(
(batch_size, input_img_h, input_img_w), dtype=mlvl_feats[0].dtype)
for img_id in range(batch_size):
img_h, img_w, _ = img_metas[img_id]['img_shape']
img_masks[img_id, :img_h, :img_w] = 1
mlvl_masks = []
mlvl_positional_encodings = []
for feat in mlvl_feats:
mlvl_masks.append(
F.interpolate(
img_masks[None], size=feat.shape[-2:]).squeeze(0))
mlvl_positional_encodings.append(
self.positional_encoding(mlvl_masks[-1]).transpose(
[0, 3, 1, 2]))
query_embeds = self.query_embedding.weight
hs, init_reference, inter_references, \
enc_outputs_class, enc_outputs_kpt, hm_proto, memory = \
self.transformer(
mlvl_feats,
mlvl_masks,
query_embeds,
mlvl_positional_encodings,
kpt_branches=self.kpt_branches \
if self.with_kpt_refine else None, # noqa:E501
cls_branches=self.cls_branches \
if self.as_two_stage else None # noqa:E501
)
outputs_classes = []
outputs_kpts = []
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.cls_branches[lvl](hs[lvl])
tmp_kpt = self.kpt_branches[lvl](hs[lvl])
assert reference.shape[-1] == self.num_keypoints * 2
tmp_kpt += reference
outputs_kpt = F.sigmoid(tmp_kpt)
outputs_classes.append(outputs_class)
outputs_kpts.append(outputs_kpt)
outputs_classes = paddle.stack(outputs_classes)
outputs_kpts = paddle.stack(outputs_kpts)
if hm_proto is not None:
# get heatmap prediction (training phase)
hm_memory, hm_mask = hm_proto
hm_pred = self.fc_hm(hm_memory)
hm_proto = (hm_pred.transpose((0, 3, 1, 2)), hm_mask)
if self.as_two_stage:
return outputs_classes, outputs_kpts, \
enc_outputs_class, F.sigmoid(enc_outputs_kpt), \
hm_proto, memory, mlvl_masks
else:
raise RuntimeError('only "as_two_stage=True" is supported.')
def forward_refine(self, memory, mlvl_masks, refine_targets, losses,
img_metas):
"""Forward function.
Args:
mlvl_masks (tuple[Tensor]): The key_padding_mask from
different level used for encoder and decoder,
each is a 3D-tensor with shape (bs, H, W).
losses (dict[str, Tensor]): A dictionary of loss components.
img_metas (list[dict]): List of image information.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
kpt_preds, kpt_targets, area_targets, kpt_weights = refine_targets
pos_inds = kpt_weights.sum(-1) > 0
if not pos_inds.any():
pos_kpt_preds = paddle.zeros_like(kpt_preds[:1])
pos_img_inds = paddle.zeros([1], dtype="int64")
else:
pos_kpt_preds = kpt_preds[pos_inds]
pos_img_inds = (pos_inds.nonzero() /
self.num_query).squeeze(1).astype("int64")
hs, init_reference, inter_references = self.transformer.forward_refine(
mlvl_masks,
memory,
pos_kpt_preds.detach(),
pos_img_inds,
kpt_branches=self.refine_kpt_branches
if self.with_kpt_refine else None, # noqa:E501
)
outputs_kpts = []
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
tmp_kpt = self.refine_kpt_branches[lvl](hs[lvl])
assert reference.shape[-1] == 2
tmp_kpt += reference
outputs_kpt = F.sigmoid(tmp_kpt)
outputs_kpts.append(outputs_kpt)
outputs_kpts = paddle.stack(outputs_kpts)
if not self.training:
return outputs_kpts
num_valid_kpt = paddle.clip(
reduce_mean(kpt_weights.sum()), min=1).item()
num_total_pos = paddle.to_tensor(
[outputs_kpts.shape[1]], dtype=kpt_weights.dtype)
num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item()
if not pos_inds.any():
for i, kpt_refine_preds in enumerate(outputs_kpts):
loss_kpt = loss_oks = kpt_refine_preds.sum() * 0
losses[f'd{i}.loss_kpt_refine'] = loss_kpt
losses[f'd{i}.loss_oks_refine'] = loss_oks
continue
return losses
batch_size = mlvl_masks[0].shape[0]
factors = []
for img_id in range(batch_size):
img_h, img_w, _ = img_metas[img_id]['img_shape']
factor = paddle.to_tensor(
[img_w, img_h, img_w, img_h],
dtype="float32").squeeze(-1).unsqueeze(0).tile(
(self.num_query, 1))
factors.append(factor)
factors = paddle.concat(factors, 0)
factors = factors[pos_inds][:, :2].tile((1, kpt_preds.shape[-1] // 2))
pos_kpt_weights = kpt_weights[pos_inds]
pos_kpt_targets = kpt_targets[pos_inds]
pos_kpt_targets_scaled = pos_kpt_targets * factors
pos_areas = area_targets[pos_inds]
pos_valid = kpt_weights[pos_inds][:, 0::2]
for i, kpt_refine_preds in enumerate(outputs_kpts):
if not pos_inds.any():
print("refine kpt and oks skip")
loss_kpt = loss_oks = kpt_refine_preds.sum() * 0
losses[f'd{i}.loss_kpt_refine'] = loss_kpt
losses[f'd{i}.loss_oks_refine'] = loss_oks
continue
# kpt L1 Loss
pos_refine_preds = kpt_refine_preds.reshape(
(kpt_refine_preds.shape[0], -1))
loss_kpt = self.loss_kpt_refine(
pos_refine_preds,
pos_kpt_targets,
pos_kpt_weights,
avg_factor=num_valid_kpt)
losses[f'd{i}.loss_kpt_refine'] = loss_kpt
# kpt oks loss
pos_refine_preds_scaled = pos_refine_preds * factors
assert (pos_areas > 0).all()
loss_oks = self.loss_oks_refine(
pos_refine_preds_scaled,
pos_kpt_targets_scaled,
pos_valid,
pos_areas,
avg_factor=num_total_pos)
losses[f'd{i}.loss_oks_refine'] = loss_oks
return losses
# over-write because img_metas are needed as inputs for bbox_head.
def forward_train(self,
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_keypoints=None,
gt_areas=None,
gt_bboxes_ignore=None,
proposal_cfg=None,
**kwargs):
"""Forward function for training mode.
Args:
x (list[Tensor]): Features from backbone.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_labels (list[Tensor]): Ground truth labels of each box,
shape (num_gts,).
gt_keypoints (list[Tensor]): Ground truth keypoints of the image,
shape (num_gts, K*3).
gt_areas (list[Tensor]): Ground truth mask areas of each box,
shape (num_gts,).
gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert proposal_cfg is None, '"proposal_cfg" must be None'
outs = self(x, img_metas)
memory, mlvl_masks = outs[-2:]
outs = outs[:-2]
if gt_labels is None:
loss_inputs = outs + (gt_bboxes, gt_keypoints, gt_areas, img_metas)
else:
loss_inputs = outs + (gt_bboxes, gt_labels, gt_keypoints, gt_areas,
img_metas)
losses_and_targets = self.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
# losses = losses_and_targets
losses, refine_targets = losses_and_targets
# get pose refinement loss
losses = self.forward_refine(memory, mlvl_masks, refine_targets, losses,
img_metas)
return losses
def loss(self,
all_cls_scores,
all_kpt_preds,
enc_cls_scores,
enc_kpt_preds,
enc_hm_proto,
gt_bboxes_list,
gt_labels_list,
gt_keypoints_list,
gt_areas_list,
img_metas,
gt_bboxes_ignore=None):
"""Loss function.
Args:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_kpt_preds (Tensor): Sigmoid regression
outputs of all decode layers. Each is a 4D-tensor with
normalized coordinate format (x_{i}, y_{i}) and shape
[nb_dec, bs, num_query, K*2].
enc_cls_scores (Tensor): Classification scores of
points on encode feature map, has shape
(N, h*w, num_classes). Only be passed when as_two_stage is
True, otherwise is None.
enc_kpt_preds (Tensor): Regression results of each points
on the encode feature map, has shape (N, h*w, K*2). Only be
passed when as_two_stage is True, otherwise is None.
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_keypoints_list (list[Tensor]): Ground truth keypoints for each
image with shape (num_gts, K*3) in [p^{1}_x, p^{1}_y, p^{1}_v,
..., p^{K}_x, p^{K}_y, p^{K}_v] format.
gt_areas_list (list[Tensor]): Ground truth mask areas for each
image with shape (num_gts, ).
img_metas (list[dict]): List of image meta information.
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'gt_bboxes_ignore set to None.'
num_dec_layers = len(all_cls_scores)
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
all_gt_keypoints_list = [
gt_keypoints_list for _ in range(num_dec_layers)
]
all_gt_areas_list = [gt_areas_list for _ in range(num_dec_layers)]
img_metas_list = [img_metas for _ in range(num_dec_layers)]
losses_cls, losses_kpt, losses_oks, kpt_preds_list, kpt_targets_list, \
area_targets_list, kpt_weights_list = multi_apply(
self.loss_single, all_cls_scores, all_kpt_preds,
all_gt_labels_list, all_gt_keypoints_list,
all_gt_areas_list, img_metas_list)
loss_dict = dict()
# loss of proposal generated from encode feature map.
if enc_cls_scores is not None:
binary_labels_list = [
paddle.zeros_like(gt_labels_list[i])
for i in range(len(img_metas))
]
enc_loss_cls, enc_losses_kpt = \
self.loss_single_rpn(
enc_cls_scores, enc_kpt_preds, binary_labels_list,
gt_keypoints_list, gt_areas_list, img_metas)
loss_dict['enc_loss_cls'] = enc_loss_cls
loss_dict['enc_loss_kpt'] = enc_losses_kpt
# loss from the last decoder layer
loss_dict['loss_cls'] = losses_cls[-1]
loss_dict['loss_kpt'] = losses_kpt[-1]
loss_dict['loss_oks'] = losses_oks[-1]
# loss from other decoder layers
num_dec_layer = 0
for loss_cls_i, loss_kpt_i, loss_oks_i in zip(
losses_cls[:-1], losses_kpt[:-1], losses_oks[:-1]):
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_kpt'] = loss_kpt_i
loss_dict[f'd{num_dec_layer}.loss_oks'] = loss_oks_i
num_dec_layer += 1
# losses of heatmap generated from P3 feature map
hm_pred, hm_mask = enc_hm_proto
loss_hm = self.loss_heatmap(hm_pred, hm_mask, gt_keypoints_list,
gt_labels_list, gt_bboxes_list)
loss_dict['loss_hm'] = loss_hm
return loss_dict, (kpt_preds_list[-1], kpt_targets_list[-1],
area_targets_list[-1], kpt_weights_list[-1])
def loss_heatmap(self, hm_pred, hm_mask, gt_keypoints, gt_labels,
gt_bboxes):
assert hm_pred.shape[-2:] == hm_mask.shape[-2:]
num_img, _, h, w = hm_pred.shape
# placeholder of heatmap target (Gaussian distribution)
hm_target = paddle.zeros(hm_pred.shape, hm_pred.dtype)
for i, (gt_label, gt_bbox, gt_keypoint
) in enumerate(zip(gt_labels, gt_bboxes, gt_keypoints)):
if gt_label.shape[0] == 0:
continue
gt_keypoint = gt_keypoint.reshape((gt_keypoint.shape[0], -1,
3)).clone()
gt_keypoint[..., :2] /= 8
assert gt_keypoint[..., 0].max() <= w + 0.5 # new coordinate system
assert gt_keypoint[..., 1].max() <= h + 0.5 # new coordinate system
gt_bbox /= 8
gt_w = gt_bbox[:, 2] - gt_bbox[:, 0]
gt_h = gt_bbox[:, 3] - gt_bbox[:, 1]
for j in range(gt_label.shape[0]):
# get heatmap radius
kp_radius = paddle.clip(
paddle.floor(
gaussian_radius(
(gt_h[j], gt_w[j]), min_overlap=0.9)),
min=0,
max=3)
for k in range(self.num_keypoints):
if gt_keypoint[j, k, 2] > 0:
gt_kp = gt_keypoint[j, k, :2]
gt_kp_int = paddle.floor(gt_kp)
hm_target[i, k] = draw_umich_gaussian(
hm_target[i, k], gt_kp_int, kp_radius)
# compute heatmap loss
hm_pred = paddle.clip(
F.sigmoid(hm_pred), min=1e-4, max=1 - 1e-4) # refer to CenterNet
loss_hm = self.loss_hm(
hm_pred,
hm_target.detach(),
mask=~hm_mask.astype("bool").unsqueeze(1))
return loss_hm
def loss_single(self, cls_scores, kpt_preds, gt_labels_list,
gt_keypoints_list, gt_areas_list, img_metas):
"""Loss function for outputs from a single decoder layer of a single
feature level.
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
for all images. Shape [bs, num_query, cls_out_channels].
kpt_preds (Tensor): Sigmoid outputs from a single decoder layer
for all images, with normalized coordinate (x_{i}, y_{i}) and
shape [bs, num_query, K*2].
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_keypoints_list (list[Tensor]): Ground truth keypoints for each
image with shape (num_gts, K*3) in [p^{1}_x, p^{1}_y, p^{1}_v,
..., p^{K}_x, p^{K}_y, p^{K}_v] format.
gt_areas_list (list[Tensor]): Ground truth mask areas for each
image with shape (num_gts, ).
img_metas (list[dict]): List of image meta information.
Returns:
dict[str, Tensor]: A dictionary of loss components for outputs from
a single decoder layer.
"""
num_imgs = cls_scores.shape[0]
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
kpt_preds_list = [kpt_preds[i] for i in range(num_imgs)]
cls_reg_targets = self.get_targets(cls_scores_list, kpt_preds_list,
gt_labels_list, gt_keypoints_list,
gt_areas_list, img_metas)
(labels_list, label_weights_list, kpt_targets_list, kpt_weights_list,
area_targets_list, num_total_pos, num_total_neg) = cls_reg_targets
labels = paddle.concat(labels_list, 0)
label_weights = paddle.concat(label_weights_list, 0)
kpt_targets = paddle.concat(kpt_targets_list, 0)
kpt_weights = paddle.concat(kpt_weights_list, 0)
area_targets = paddle.concat(area_targets_list, 0)
# classification loss
cls_scores = cls_scores.reshape((-1, self.cls_out_channels))
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean(
paddle.to_tensor(
[cls_avg_factor], dtype=cls_scores.dtype))
cls_avg_factor = max(cls_avg_factor, 1)
loss_cls = self.loss_cls(
cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
        # Compute the average number of gt keypoints across all GPUs, for
# normalization purposes
num_total_pos = paddle.to_tensor([num_total_pos], dtype=loss_cls.dtype)
num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item()
# construct factors used for rescale keypoints
factors = []
for img_meta, kpt_pred in zip(img_metas, kpt_preds):
img_h, img_w, _ = img_meta['img_shape']
factor = paddle.to_tensor(
[img_w, img_h, img_w, img_h],
dtype=kpt_pred.dtype).squeeze().unsqueeze(0).tile(
(kpt_pred.shape[0], 1))
factors.append(factor)
factors = paddle.concat(factors, 0)
# keypoint regression loss
kpt_preds = kpt_preds.reshape((-1, kpt_preds.shape[-1]))
num_valid_kpt = paddle.clip(
reduce_mean(kpt_weights.sum()), min=1).item()
# assert num_valid_kpt == (kpt_targets>0).sum().item()
loss_kpt = self.loss_kpt(
kpt_preds,
kpt_targets.detach(),
kpt_weights.detach(),
avg_factor=num_valid_kpt)
# keypoint oks loss
pos_inds = kpt_weights.sum(-1) > 0
if not pos_inds.any():
loss_oks = kpt_preds.sum() * 0
else:
factors = factors[pos_inds][:, :2].tile((
(1, kpt_preds.shape[-1] // 2)))
pos_kpt_preds = kpt_preds[pos_inds] * factors
pos_kpt_targets = kpt_targets[pos_inds] * factors
pos_areas = area_targets[pos_inds]
pos_valid = kpt_weights[pos_inds][..., 0::2]
assert (pos_areas > 0).all()
loss_oks = self.loss_oks(
pos_kpt_preds,
pos_kpt_targets,
pos_valid,
pos_areas,
avg_factor=num_total_pos)
return loss_cls, loss_kpt, loss_oks, kpt_preds, kpt_targets, \
area_targets, kpt_weights
def get_targets(self, cls_scores_list, kpt_preds_list, gt_labels_list,
gt_keypoints_list, gt_areas_list, img_metas):
"""Compute regression and classification targets for a batch image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_scores_list (list[Tensor]): Box score logits from a single
decoder layer for each image with shape [num_query,
cls_out_channels].
kpt_preds_list (list[Tensor]): Sigmoid outputs from a single
decoder layer for each image, with normalized coordinate
(x_{i}, y_{i}) and shape [num_query, K*2].
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_keypoints_list (list[Tensor]): Ground truth keypoints for each
image with shape (num_gts, K*3).
gt_areas_list (list[Tensor]): Ground truth mask areas for each
image with shape (num_gts, ).
img_metas (list[dict]): List of image meta information.
Returns:
tuple: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels for all images.
- label_weights_list (list[Tensor]): Label weights for all
images.
- kpt_targets_list (list[Tensor]): Keypoint targets for all
images.
- kpt_weights_list (list[Tensor]): Keypoint weights for all
images.
- area_targets_list (list[Tensor]): area targets for all
images.
- num_total_pos (int): Number of positive samples in all
images.
- num_total_neg (int): Number of negative samples in all
images.
"""
(labels_list, label_weights_list, kpt_targets_list, kpt_weights_list,
area_targets_list, pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single, cls_scores_list, kpt_preds_list,
gt_labels_list, gt_keypoints_list, gt_areas_list, img_metas)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
return (labels_list, label_weights_list, kpt_targets_list,
kpt_weights_list, area_targets_list, num_total_pos,
num_total_neg)
def _get_target_single(self, cls_score, kpt_pred, gt_labels, gt_keypoints,
gt_areas, img_meta):
"""Compute regression and classification targets for one image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_score (Tensor): Box score logits from a single decoder layer
for one image. Shape [num_query, cls_out_channels].
kpt_pred (Tensor): Sigmoid outputs from a single decoder layer
for one image, with normalized coordinate (x_{i}, y_{i}) and
shape [num_query, K*2].
gt_labels (Tensor): Ground truth class indices for one image
with shape (num_gts, ).
gt_keypoints (Tensor): Ground truth keypoints for one image with
shape (num_gts, K*3) in [p^{1}_x, p^{1}_y, p^{1}_v, ..., \
p^{K}_x, p^{K}_y, p^{K}_v] format.
gt_areas (Tensor): Ground truth mask areas for one image
with shape (num_gts, ).
img_meta (dict): Meta information for one image.
Returns:
tuple[Tensor]: a tuple containing the following for one image.
- labels (Tensor): Labels of each image.
- label_weights (Tensor): Label weights of each image.
- kpt_targets (Tensor): Keypoint targets of each image.
- kpt_weights (Tensor): Keypoint weights of each image.
- area_targets (Tensor): Area targets of each image.
- pos_inds (Tensor): Sampled positive indices for each image.
- neg_inds (Tensor): Sampled negative indices for each image.
"""
num_bboxes = kpt_pred.shape[0]
# assigner and sampler
assign_result = self.assigner.assign(cls_score, kpt_pred, gt_labels,
gt_keypoints, gt_areas, img_meta)
sampling_result = self.sampler.sample(assign_result, kpt_pred,
gt_keypoints)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
# label targets
labels = paddle.full((num_bboxes, ), self.num_classes, dtype="int64")
label_weights = paddle.ones((num_bboxes, ), dtype=gt_labels.dtype)
kpt_targets = paddle.zeros_like(kpt_pred)
kpt_weights = paddle.zeros_like(kpt_pred)
area_targets = paddle.zeros((kpt_pred.shape[0], ), dtype=kpt_pred.dtype)
if pos_inds.size == 0:
return (labels, label_weights, kpt_targets, kpt_weights,
area_targets, pos_inds, neg_inds)
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds][
..., 0].astype("int64")
img_h, img_w, _ = img_meta['img_shape']
# keypoint targets
pos_gt_kpts = gt_keypoints[sampling_result.pos_assigned_gt_inds]
pos_gt_kpts = pos_gt_kpts.reshape(
(len(sampling_result.pos_assigned_gt_inds), -1, 3))
valid_idx = pos_gt_kpts[:, :, 2] > 0
pos_kpt_weights = kpt_weights[pos_inds].reshape(
(pos_gt_kpts.shape[0], kpt_weights.shape[-1] // 2, 2))
# pos_kpt_weights[valid_idx][...] = 1.0
pos_kpt_weights = masked_fill(pos_kpt_weights,
valid_idx.unsqueeze(-1), 1.0)
kpt_weights[pos_inds] = pos_kpt_weights.reshape(
(pos_kpt_weights.shape[0], kpt_pred.shape[-1]))
factor = paddle.to_tensor(
[img_w, img_h], dtype=kpt_pred.dtype).squeeze().unsqueeze(0)
pos_gt_kpts_normalized = pos_gt_kpts[..., :2]
pos_gt_kpts_normalized[..., 0] = pos_gt_kpts_normalized[..., 0] / \
factor[:, 0:1]
pos_gt_kpts_normalized[..., 1] = pos_gt_kpts_normalized[..., 1] / \
factor[:, 1:2]
kpt_targets[pos_inds] = pos_gt_kpts_normalized.reshape(
(pos_gt_kpts.shape[0], kpt_pred.shape[-1]))
pos_gt_areas = gt_areas[sampling_result.pos_assigned_gt_inds][..., 0]
area_targets[pos_inds] = pos_gt_areas
return (labels, label_weights, kpt_targets, kpt_weights, area_targets,
pos_inds, neg_inds)
def loss_single_rpn(self, cls_scores, kpt_preds, gt_labels_list,
gt_keypoints_list, gt_areas_list, img_metas):
"""Loss function for outputs from a single decoder layer of a single
feature level.
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
for all images. Shape [bs, num_query, cls_out_channels].
kpt_preds (Tensor): Sigmoid outputs from a single decoder layer
for all images, with normalized coordinate (x_{i}, y_{i}) and
shape [bs, num_query, K*2].
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_keypoints_list (list[Tensor]): Ground truth keypoints for each
image with shape (num_gts, K*3) in [p^{1}_x, p^{1}_y, p^{1}_v,
..., p^{K}_x, p^{K}_y, p^{K}_v] format.
gt_areas_list (list[Tensor]): Ground truth mask areas for each
image with shape (num_gts, ).
img_metas (list[dict]): List of image meta information.
Returns:
            tuple[Tensor]: the classification loss and keypoint regression loss
                for outputs from a single decoder layer.
"""
num_imgs = cls_scores.shape[0]
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
kpt_preds_list = [kpt_preds[i] for i in range(num_imgs)]
cls_reg_targets = self.get_targets(cls_scores_list, kpt_preds_list,
gt_labels_list, gt_keypoints_list,
gt_areas_list, img_metas)
(labels_list, label_weights_list, kpt_targets_list, kpt_weights_list,
area_targets_list, num_total_pos, num_total_neg) = cls_reg_targets
labels = paddle.concat(labels_list, 0)
label_weights = paddle.concat(label_weights_list, 0)
kpt_targets = paddle.concat(kpt_targets_list, 0)
kpt_weights = paddle.concat(kpt_weights_list, 0)
# classification loss
cls_scores = cls_scores.reshape((-1, self.cls_out_channels))
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean(
paddle.to_tensor(
[cls_avg_factor], dtype=cls_scores.dtype))
cls_avg_factor = max(cls_avg_factor, 1)
loss_cls = self.loss_cls(
cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
        # Compute the average number of gt keypoints across all gpus, for
# normalization purposes
# num_total_pos = loss_cls.to_tensor([num_total_pos])
# num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item()
# keypoint regression loss
kpt_preds = kpt_preds.reshape((-1, kpt_preds.shape[-1]))
num_valid_kpt = paddle.clip(
reduce_mean(kpt_weights.sum()), min=1).item()
# assert num_valid_kpt == (kpt_targets>0).sum().item()
loss_kpt = self.loss_kpt_rpn(
kpt_preds, kpt_targets, kpt_weights, avg_factor=num_valid_kpt)
return loss_cls, loss_kpt
def get_bboxes(self,
all_cls_scores,
all_kpt_preds,
enc_cls_scores,
enc_kpt_preds,
hm_proto,
memory,
mlvl_masks,
img_metas,
rescale=False):
"""Transform network outputs for a batch into bbox predictions.
Args:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_kpt_preds (Tensor): Sigmoid regression
outputs of all decode layers. Each is a 4D-tensor with
normalized coordinate format (x_{i}, y_{i}) and shape
[nb_dec, bs, num_query, K*2].
enc_cls_scores (Tensor): Classification scores of points on
encode feature map, has shape (N, h*w, num_classes).
Only be passed when as_two_stage is True, otherwise is None.
enc_kpt_preds (Tensor): Regression results of each points
on the encode feature map, has shape (N, h*w, K*2). Only be
passed when as_two_stage is True, otherwise is None.
img_metas (list[dict]): Meta information of each image.
rescale (bool, optional): If True, return boxes in original
                image space. Default: False.
Returns:
            list[list[Tensor, Tensor]]: Each item in result_list is a 3-tuple.
The first item is an (n, 5) tensor, where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1. The second item is a
(n,) tensor where each item is the predicted class label of
the corresponding box. The third item is an (n, K, 3) tensor
with [p^{1}_x, p^{1}_y, p^{1}_v, ..., p^{K}_x, p^{K}_y,
p^{K}_v] format.
"""
cls_scores = all_cls_scores[-1]
kpt_preds = all_kpt_preds[-1]
result_list = []
for img_id in range(len(img_metas)):
cls_score = cls_scores[img_id]
kpt_pred = kpt_preds[img_id]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
# TODO: only support single image test
# memory_i = memory[:, img_id, :]
# mlvl_mask = mlvl_masks[img_id]
proposals = self._get_bboxes_single(cls_score, kpt_pred, img_shape,
scale_factor, memory,
mlvl_masks, rescale)
result_list.append(proposals)
return result_list
def _get_bboxes_single(self,
cls_score,
kpt_pred,
img_shape,
scale_factor,
memory,
mlvl_masks,
rescale=False):
"""Transform outputs from the last decoder layer into bbox predictions
for each image.
Args:
cls_score (Tensor): Box score logits from the last decoder layer
for each image. Shape [num_query, cls_out_channels].
kpt_pred (Tensor): Sigmoid outputs from the last decoder layer
for each image, with coordinate format (x_{i}, y_{i}) and
shape [num_query, K*2].
img_shape (tuple[int]): Shape of input image, (height, width, 3).
            scale_factor (ndarray, optional): Scale factor of the image, arranged
                as (w_scale, h_scale, w_scale, h_scale).
rescale (bool, optional): If True, return boxes in original image
space. Default False.
Returns:
tuple[Tensor]: Results of detected bboxes and labels.
- det_bboxes: Predicted bboxes with shape [num_query, 5],
where the first 4 columns are bounding box positions
(tl_x, tl_y, br_x, br_y) and the 5-th column are scores
between 0 and 1.
- det_labels: Predicted labels of the corresponding box with
shape [num_query].
- det_kpts: Predicted keypoints with shape [num_query, K, 3].
"""
assert len(cls_score) == len(kpt_pred)
max_per_img = self.test_cfg.get('max_per_img', self.num_query)
# exclude background
if self.loss_cls.use_sigmoid:
cls_score = F.sigmoid(cls_score)
scores, indexs = cls_score.reshape([-1]).topk(max_per_img)
det_labels = indexs % self.num_classes
bbox_index = indexs // self.num_classes
kpt_pred = kpt_pred[bbox_index]
else:
            probs = F.softmax(cls_score, axis=-1)[..., :-1]
            scores, det_labels = probs.max(-1), probs.argmax(-1)
scores, bbox_index = scores.topk(max_per_img)
kpt_pred = kpt_pred[bbox_index]
det_labels = det_labels[bbox_index]
# ----- results after pose decoder -----
# det_kpts = kpt_pred.reshape((kpt_pred.shape[0], -1, 2))
# ----- results after joint decoder (default) -----
# import time
# start = time.time()
refine_targets = (kpt_pred, None, None, paddle.ones_like(kpt_pred))
refine_outputs = self.forward_refine(memory, mlvl_masks, refine_targets,
None, None)
# end = time.time()
# print(f'refine time: {end - start:.6f}')
det_kpts = refine_outputs[-1]
det_kpts[..., 0] = det_kpts[..., 0] * img_shape[1]
det_kpts[..., 1] = det_kpts[..., 1] * img_shape[0]
det_kpts[..., 0].clip_(min=0, max=img_shape[1])
det_kpts[..., 1].clip_(min=0, max=img_shape[0])
if rescale:
det_kpts /= paddle.to_tensor(
scale_factor[:2],
dtype=det_kpts.dtype).unsqueeze(0).unsqueeze(0)
# use circumscribed rectangle box of keypoints as det bboxes
x1 = det_kpts[..., 0].min(axis=1, keepdim=True)
y1 = det_kpts[..., 1].min(axis=1, keepdim=True)
x2 = det_kpts[..., 0].max(axis=1, keepdim=True)
y2 = det_kpts[..., 1].max(axis=1, keepdim=True)
det_bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
det_bboxes = paddle.concat((det_bboxes, scores.unsqueeze(1)), -1)
det_kpts = paddle.concat(
(det_kpts, paddle.ones(
det_kpts[..., :1].shape, dtype=det_kpts.dtype)),
axis=2)
return det_bboxes, det_labels, det_kpts
def simple_test(self, feats, img_metas, rescale=False):
"""Test det bboxes without test-time augmentation.
Args:
feats (tuple[paddle.Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[tuple[Tensor, Tensor, Tensor]]: Each item in result_list is
3-tuple. The first item is ``bboxes`` with shape (n, 5),
where 5 represent (tl_x, tl_y, br_x, br_y, score).
The shape of the second tensor in the tuple is ``labels``
with shape (n,). The third item is ``kpts`` with shape
            (n, K, 3), in [p^{1}_x, p^{1}_y, p^{1}_v, ..., p^{K}_x, p^{K}_y,
            p^{K}_v] format.
"""
# forward of this head requires img_metas
outs = self.forward(feats, img_metas)
results_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
return results_list
def get_loss(self, boxes, scores, gt_bbox, gt_class, prior_boxes):
return self.loss(boxes, scores, gt_bbox, gt_class, prior_boxes)
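# Illustrative sketch (not part of the diff): the sigmoid branch of
# _get_bboxes_single above flattens the (num_query, num_classes) score matrix
# and decomposes every top-k index back into a query index and a class label.
# The helper below is hypothetical and uses made-up numbers.
def _demo_flattened_topk():
    import paddle
    scores = paddle.to_tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.4]])  # 3 queries, 2 classes
    topk_scores, idx = scores.reshape([-1]).topk(2)
    labels = idx % 2       # class labels  -> [1, 0]
    query_ids = idx // 2   # query indices -> [0, 1]
    return topk_scores, labels, query_ids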
......@@ -1135,7 +1135,7 @@ def _convert_attention_mask(attn_mask, dtype):
"""
return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)
@register
class MultiHeadAttention(nn.Layer):
"""
    Attention maps queries and a set of key-value pairs to outputs, and
......
......@@ -21,7 +21,7 @@ import paddle.nn.functional as F
import paddle.nn as nn
from ppdet.core.workspace import register
__all__ = ['FocalLoss']
__all__ = ['FocalLoss', 'Weighted_FocalLoss']
@register
class FocalLoss(nn.Layer):
......@@ -59,3 +59,80 @@ class FocalLoss(nn.Layer):
pred, target, alpha=self.alpha, gamma=self.gamma,
reduction=reduction)
return loss * self.loss_weight
@register
class Weighted_FocalLoss(FocalLoss):
"""A wrapper around paddle.nn.functional.sigmoid_focal_loss.
Args:
use_sigmoid (bool): currently only support use_sigmoid=True
alpha (float): parameter alpha in Focal Loss
gamma (float): parameter gamma in Focal Loss
loss_weight (float): final loss will be multiplied by this
"""
def __init__(self,
use_sigmoid=True,
alpha=0.25,
gamma=2.0,
loss_weight=1.0,
reduction="mean"):
super(FocalLoss, self).__init__()
assert use_sigmoid == True, \
'Focal Loss only supports sigmoid at the moment'
self.use_sigmoid = use_sigmoid
self.alpha = alpha
self.gamma = gamma
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, avg_factor=None, reduction_override=None):
"""forward function.
Args:
pred (Tensor): logits of class prediction, of shape (N, num_classes)
target (Tensor): target class label, of shape (N, )
            weight (Tensor, optional): per-prediction loss weight
            avg_factor (int, optional): average factor used to normalize the loss
            reduction_override (str, optional): one of (none, sum, mean)
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
num_classes = pred.shape[1]
target = F.one_hot(target, num_classes + 1).astype(pred.dtype)
target = target[:, :-1].detach()
loss = F.sigmoid_focal_loss(
pred, target, alpha=self.alpha, gamma=self.gamma,
reduction='none')
if weight is not None:
if weight.shape != loss.shape:
if weight.shape[0] == loss.shape[0]:
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.reshape((-1, 1))
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.reshape((loss.shape[0], -1))
assert weight.ndim == loss.ndim
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = 1e-10
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss * self.loss_weight
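# Illustrative sketch (not part of the diff): hedged usage of Weighted_FocalLoss,
# assuming a working ppdet environment. Shapes and values are made up; labels
# equal to num_classes (1 here) act as background through the one-hot slicing in
# forward().
def _demo_weighted_focal_loss():
    import paddle
    loss_fn = Weighted_FocalLoss(alpha=0.25, gamma=2.0, loss_weight=2.0)
    logits = paddle.randn([300, 1])          # 300 queries, 1 foreground class
    labels = paddle.randint(0, 2, [300])     # label 1 == background
    weights = paddle.ones([300])
    return loss_fn(logits, labels, weights, avg_factor=300.0)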
......@@ -18,12 +18,13 @@ from __future__ import print_function
from itertools import cycle, islice
from collections import abc
import numpy as np
import paddle
import paddle.nn as nn
from ppdet.core.workspace import register, serializable
__all__ = ['HrHRNetLoss', 'KeyPointMSELoss']
__all__ = ['HrHRNetLoss', 'KeyPointMSELoss', 'OKSLoss', 'CenterFocalLoss', 'L1Loss']
@register
......@@ -226,3 +227,406 @@ def recursive_sum(inputs):
if isinstance(inputs, abc.Sequence):
return sum([recursive_sum(x) for x in inputs])
return inputs
def oks_overlaps(kpt_preds, kpt_gts, kpt_valids, kpt_areas, sigmas):
if not kpt_gts.astype('bool').any():
return kpt_preds.sum()*0
sigmas = paddle.to_tensor(sigmas, dtype=kpt_preds.dtype)
variances = (sigmas * 2)**2
assert kpt_preds.shape[0] == kpt_gts.shape[0]
kpt_preds = kpt_preds.reshape((-1, kpt_preds.shape[-1] // 2, 2))
kpt_gts = kpt_gts.reshape((-1, kpt_gts.shape[-1] // 2, 2))
squared_distance = (kpt_preds[:, :, 0] - kpt_gts[:, :, 0]) ** 2 + \
(kpt_preds[:, :, 1] - kpt_gts[:, :, 1]) ** 2
assert (kpt_valids.sum(-1) > 0).all()
squared_distance0 = squared_distance / (
kpt_areas[:, None] * variances[None, :] * 2)
squared_distance1 = paddle.exp(-squared_distance0)
squared_distance1 = squared_distance1 * kpt_valids
oks = squared_distance1.sum(axis=1) / kpt_valids.sum(axis=1)
return oks
def oks_loss(pred,
target,
weight,
valid=None,
area=None,
linear=False,
sigmas=None,
eps=1e-6,
avg_factor=None,
reduction=None):
"""Oks loss.
Computing the oks loss between a set of predicted poses and target poses.
The loss is calculated as negative log of oks.
Args:
pred (Tensor): Predicted poses of format (x1, y1, x2, y2, ...),
shape (n, K*2).
target (Tensor): Corresponding gt poses, shape (n, K*2).
linear (bool, optional): If True, use linear scale of loss instead of
log scale. Default: False.
eps (float): Eps to avoid log(0).
Returns:
Tensor: Loss tensor.
"""
oks = oks_overlaps(pred, target, valid, area, sigmas).clip(min=eps)
if linear:
loss = 1 - oks
else:
loss = -oks.log()
if weight is not None:
if weight.shape != loss.shape:
if weight.shape[0] == loss.shape[0]:
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.reshape((-1, 1))
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.reshape((loss.shape[0], -1))
assert weight.ndim == loss.ndim
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = 1e-10
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
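# Illustrative sketch (not part of the diff): a NumPy restatement of the OKS used
# above, exp(-d^2 / (2 * area * (2 * sigma)^2)) averaged over visible joints; a
# perfect prediction gives oks = 1 and -log(oks) = 0. The three toy joints,
# sigmas and area below are made up for the example.
def _demo_oks():
    sigmas = np.array([.26, .25, .25]) / 10.0
    pred = np.array([[10., 10., 20., 20., 30., 30.]])        # (n, K*2)
    gt = pred.copy()                                         # exact match
    valid = np.ones((1, 3))
    area = np.array([100.0])
    d2 = ((pred.reshape(1, 3, 2) - gt.reshape(1, 3, 2))**2).sum(-1)
    e = np.exp(-d2 / (area[:, None] * (2 * sigmas[None, :])**2 * 2))
    oks = (e * valid).sum(1) / valid.sum(1)
    return oks, -np.log(np.clip(oks, 1e-6, None))            # (1.0, 0.0)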
@register
@serializable
class OKSLoss(nn.Layer):
"""OKSLoss.
Computing the oks loss between a set of predicted poses and target poses.
Args:
linear (bool): If True, use linear scale of loss instead of log scale.
Default: False.
eps (float): Eps to avoid log(0).
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Weight of loss.
"""
def __init__(self,
linear=False,
num_keypoints=17,
eps=1e-6,
reduction='mean',
loss_weight=1.0):
super(OKSLoss, self).__init__()
self.linear = linear
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
if num_keypoints == 17:
self.sigmas = np.array([
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
1.07, .87, .87, .89, .89
], dtype=np.float32) / 10.0
elif num_keypoints == 14:
self.sigmas = np.array([
.79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89,
.79, .79
            ], dtype=np.float32) / 10.0
else:
raise ValueError(f'Unsupported keypoints number {num_keypoints}')
def forward(self,
pred,
target,
valid,
area,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
"""Forward function.
Args:
pred (Tensor): The prediction.
target (Tensor): The learning target of the prediction.
valid (Tensor): The visible flag of the target pose.
area (Tensor): The area of the target pose.
weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None. Options are "none", "mean" and "sum".
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if (weight is not None) and (not paddle.any(weight > 0)) and (
reduction != 'none'):
if pred.dim() == weight.dim() + 1:
weight = weight.unsqueeze(1)
return (pred * weight).sum() # 0
if weight is not None and weight.dim() > 1:
# TODO: remove this in the future
# reduce the weight of shape (n, 4) to (n,) to match the
# iou_loss of shape (n,)
assert weight.shape == pred.shape
weight = weight.mean(-1)
loss = self.loss_weight * oks_loss(
pred,
target,
weight,
valid=valid,
area=area,
linear=self.linear,
sigmas=self.sigmas,
eps=self.eps,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss
def center_focal_loss(pred, gt, weight=None, mask=None, avg_factor=None, reduction=None):
"""Modified focal loss. Exactly the same as CornerNet.
Runs faster and costs a little bit more memory.
Args:
pred (Tensor): The prediction with shape [bs, c, h, w].
gt (Tensor): The learning target of the prediction in gaussian
distribution, with shape [bs, c, h, w].
mask (Tensor): The valid mask. Defaults to None.
"""
if not gt.astype('bool').any():
return pred.sum()*0
pos_inds = gt.equal(1).astype('float32')
if mask is None:
neg_inds = gt.less_than(paddle.to_tensor([1], dtype='float32')).astype('float32')
else:
neg_inds = gt.less_than(paddle.to_tensor([1], dtype='float32')).astype('float32') * mask.equal(0).astype('float32')
neg_weights = paddle.pow(1 - gt, 4)
loss = 0
pos_loss = paddle.log(pred) * paddle.pow(1 - pred, 2) * pos_inds
neg_loss = paddle.log(1 - pred) * paddle.pow(pred, 2) * neg_weights * \
neg_inds
num_pos = pos_inds.astype('float32').sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if num_pos == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
if weight is not None:
if weight.shape != loss.shape:
if weight.shape[0] == loss.shape[0]:
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.reshape((-1, 1))
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.reshape((loss.shape[0], -1))
assert weight.ndim == loss.ndim
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = 1e-10
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
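# Illustrative sketch (not part of the diff): the penalty-reduced focal terms
# used above written out in NumPy for a 2x2 toy heatmap; a confident, correct
# prediction at the single positive pixel yields a small loss.
def _demo_center_focal():
    pred = np.array([[0.90, 0.05], [0.05, 0.05]])
    gt = np.array([[1.00, 0.00], [0.00, 0.00]])
    pos = np.log(pred) * (1 - pred)**2 * (gt == 1)
    neg = np.log(1 - pred) * pred**2 * (1 - gt)**4 * (gt < 1)
    return -(pos.sum() + neg.sum()) / max((gt == 1).sum(), 1)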
@register
@serializable
class CenterFocalLoss(nn.Layer):
"""CenterFocalLoss is a variant of focal loss.
More details can be found in the `paper
<https://arxiv.org/abs/1808.01244>`_
Args:
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Loss weight of current loss.
"""
def __init__(self,
reduction='none',
loss_weight=1.0):
super(CenterFocalLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
mask=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (Tensor): The prediction.
target (Tensor): The learning target of the prediction in gaussian
distribution.
weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None.
mask (Tensor): The valid mask. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_reg = self.loss_weight * center_focal_loss(
pred,
target,
weight,
mask=mask,
reduction=reduction,
avg_factor=avg_factor)
return loss_reg
def l1_loss(pred, target, weight=None, reduction='mean', avg_factor=None):
"""L1 loss.
Args:
pred (Tensor): The prediction.
target (Tensor): The learning target of the prediction.
Returns:
Tensor: Calculated loss
"""
if not target.astype('bool').any():
return pred.sum() * 0
assert pred.shape == target.shape
loss = paddle.abs(pred - target)
if weight is not None:
if weight.shape != loss.shape:
if weight.shape[0] == loss.shape[0]:
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.reshape((-1, 1))
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.reshape((loss.shape[0], -1))
assert weight.ndim == loss.ndim
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = 1e-10
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
@register
@serializable
class L1Loss(nn.Layer):
"""L1 loss.
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
super(L1Loss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (Tensor): The prediction.
target (Tensor): The learning target of the prediction.
weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_bbox
......@@ -36,3 +36,4 @@ from .es_pan import *
from .lc_pan import *
from .custom_pan import *
from .dilated_encoder import *
from .channel_mapper import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on mmdet: git@github.com:open-mmlab/mmdetection.git
"""
import paddle.nn as nn
from ppdet.core.workspace import register, serializable
from ..backbones.hrnet import ConvNormLayer
from ..shape_spec import ShapeSpec
from ..initializer import xavier_uniform_, constant_
__all__ = ['ChannelMapper']
@register
@serializable
class ChannelMapper(nn.Layer):
"""Channel Mapper to reduce/increase channels of backbone features.
This is used to reduce/increase channels of backbone features.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
kernel_size (int, optional): kernel_size for reducing channels (used
at each scale). Default: 3.
        norm_type (str, optional): Type of the normalization layer used in
            ConvNormLayer. Default: "gn".
        norm_groups (int, optional): Number of groups for group normalization.
            Default: 32.
        act (str, optional): Activation used in ConvNormLayer.
            Default: 'relu'.
num_outs (int, optional): Number of output feature maps. There
            would be extra convs when num_outs is larger than the length
of in_channels.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
norm_type="gn",
norm_groups=32,
act='relu',
num_outs=None,
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super(ChannelMapper, self).__init__()
assert isinstance(in_channels, list)
self.extra_convs = None
if num_outs is None:
num_outs = len(in_channels)
self.convs = nn.LayerList()
for in_channel in in_channels:
self.convs.append(
ConvNormLayer(
ch_in=in_channel,
ch_out=out_channels,
filter_size=kernel_size,
                    norm_type=norm_type,
                    norm_groups=norm_groups,
act=act))
if num_outs > len(in_channels):
self.extra_convs = nn.LayerList()
for i in range(len(in_channels), num_outs):
if i == len(in_channels):
in_channel = in_channels[-1]
else:
in_channel = out_channels
self.extra_convs.append(
ConvNormLayer(
ch_in=in_channel,
ch_out=out_channels,
filter_size=3,
stride=2,
                        norm_type=norm_type,
                        norm_groups=norm_groups,
act=act))
self.init_weights()
def forward(self, inputs):
"""Forward function."""
assert len(inputs) == len(self.convs)
outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
if self.extra_convs:
for i in range(len(self.extra_convs)):
if i == 0:
outs.append(self.extra_convs[0](inputs[-1]))
else:
outs.append(self.extra_convs[i](outs[-1]))
return tuple(outs)
@property
def out_shape(self):
return [
ShapeSpec(
channels=self.out_channel, stride=1. / s)
for s in self.spatial_scales
]
def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.rank() > 1:
xavier_uniform_(p)
if hasattr(p, 'bias') and p.bias is not None:
                    constant_(p.bias)
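# Illustrative sketch (not part of the diff): a standalone shape walk-through of
# what ChannelMapper does, using plain paddle convolutions instead of the ppdet
# ConvNormLayer; channel counts and spatial sizes are made up for the example.
def _demo_channel_mapper_shapes():
    import paddle
    feats = [paddle.rand([1, c, s, s]) for c, s in [(512, 64), (1024, 32), (2048, 16)]]
    laterals = [nn.Conv2D(c, 256, 3, padding=1) for c in (512, 1024, 2048)]
    extra = nn.Conv2D(2048, 256, 3, stride=2, padding=1)  # added when num_outs > len(in_channels)
    outs = [m(x) for m, x in zip(laterals, feats)] + [extra(feats[-1])]
    # -> [(1, 256, 64, 64), (1, 256, 32, 32), (1, 256, 16, 16), (1, 256, 8, 8)]
    return [tuple(o.shape) for o in outs]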
......@@ -25,3 +25,4 @@ from .matchers import *
from .position_encoding import *
from .deformable_transformer import *
from .dino_transformer import *
from .petr_transformer import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on https://github.com/hikvision-research/opera/blob/main/opera/models/utils/transformer.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import warnings
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppdet.core.workspace import register
from ..layers import MultiHeadAttention, _convert_attention_mask
from .utils import _get_clones
from ..initializer import linear_init_, normal_, constant_, xavier_uniform_
__all__ = [
'PETRTransformer', 'MultiScaleDeformablePoseAttention',
'PETR_TransformerDecoderLayer', 'PETR_TransformerDecoder',
'PETR_DeformableDetrTransformerDecoder',
'PETR_DeformableTransformerDecoder', 'TransformerEncoderLayer',
'TransformerEncoder', 'MSDeformableAttention'
]
def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
def inverse_sigmoid(x, eps=1e-5):
"""Inverse function of sigmoid.
Args:
x (Tensor): The tensor to do the
inverse.
eps (float): EPS avoid numerical
overflow. Defaults 1e-5.
Returns:
Tensor: The x has passed the inverse
function of sigmoid, has same
shape with input.
"""
x = x.clip(min=0, max=1)
x1 = x.clip(min=eps)
x2 = (1 - x).clip(min=eps)
return paddle.log(x1 / x2)
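# Illustrative sketch (not part of the diff): inverse_sigmoid is the clipped
# logit function, so F.sigmoid(inverse_sigmoid(x)) recovers x for x inside
# (eps, 1 - eps); the decoders below rely on this when refining reference
# points in logit space. Toy values only.
def _demo_inverse_sigmoid():
    x = paddle.to_tensor([0.25, 0.5, 0.75])
    logits = inverse_sigmoid(x)     # ~[-1.0986, 0.0, 1.0986]
    return F.sigmoid(logits)        # ~[0.25, 0.5, 0.75]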
@register
class TransformerEncoderLayer(nn.Layer):
__inject__ = ['attn']
def __init__(self,
d_model,
attn=None,
nhead=8,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
attn_dropout=None,
act_dropout=None,
normalize_before=False):
super(TransformerEncoderLayer, self).__init__()
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.embed_dims = d_model
if attn is None:
self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
else:
self.self_attn = attn
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train")
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
self._reset_parameters()
def _reset_parameters(self):
linear_init_(self.linear1)
linear_init_(self.linear2)
@staticmethod
def with_pos_embed(tensor, pos_embed):
return tensor if pos_embed is None else tensor + pos_embed
def forward(self, src, src_mask=None, pos_embed=None, **kwargs):
residual = src
if self.normalize_before:
src = self.norm1(src)
q = k = self.with_pos_embed(src, pos_embed)
src = self.self_attn(q, k, value=src, attn_mask=src_mask, **kwargs)
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return src
@register
class TransformerEncoder(nn.Layer):
__inject__ = ['encoder_layer']
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.embed_dims = encoder_layer.embed_dims
def forward(self, src, src_mask=None, pos_embed=None, **kwargs):
output = src
for layer in self.layers:
output = layer(
output, src_mask=src_mask, pos_embed=pos_embed, **kwargs)
if self.norm is not None:
output = self.norm(output)
return output
@register
class MSDeformableAttention(nn.Layer):
def __init__(self,
embed_dim=256,
num_heads=8,
num_levels=4,
num_points=4,
lr_mult=0.1):
"""
Multi-Scale Deformable Attention Module
"""
super(MSDeformableAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_levels = num_levels
self.num_points = num_points
self.total_points = num_heads * num_levels * num_points
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.sampling_offsets = nn.Linear(
embed_dim,
self.total_points * 2,
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult))
self.attention_weights = nn.Linear(embed_dim, self.total_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
try:
            # use the custom CUDA op when it is available
            from deformable_detr_ops import ms_deformable_attn
            print("use deformable_detr_ops in ms_deformable_attn")
        except ImportError:
            # fall back to the pure paddle implementation
from .utils import deformable_attention_core_func as ms_deformable_attn
self.ms_deformable_attn_core = ms_deformable_attn
self._reset_parameters()
def _reset_parameters(self):
# sampling_offsets
constant_(self.sampling_offsets.weight)
thetas = paddle.arange(
self.num_heads,
dtype=paddle.float32) * (2.0 * math.pi / self.num_heads)
grid_init = paddle.stack([thetas.cos(), thetas.sin()], -1)
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)
grid_init = grid_init.reshape([self.num_heads, 1, 1, 2]).tile(
[1, self.num_levels, self.num_points, 1])
scaling = paddle.arange(
1, self.num_points + 1,
dtype=paddle.float32).reshape([1, 1, -1, 1])
grid_init *= scaling
self.sampling_offsets.bias.set_value(grid_init.flatten())
# attention_weights
constant_(self.attention_weights.weight)
constant_(self.attention_weights.bias)
# proj
xavier_uniform_(self.value_proj.weight)
constant_(self.value_proj.bias)
xavier_uniform_(self.output_proj.weight)
constant_(self.output_proj.bias)
def forward(self,
query,
key,
value,
reference_points,
value_spatial_shapes,
value_level_start_index,
attn_mask=None,
**kwargs):
"""
Args:
query (Tensor): [bs, query_length, C]
reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area
value (Tensor): [bs, value_length, C]
value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
attn_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
Returns:
output (Tensor): [bs, Length_{query}, C]
"""
bs, Len_q = query.shape[:2]
Len_v = value.shape[1]
assert int(value_spatial_shapes.prod(1).sum()) == Len_v
value = self.value_proj(value)
if attn_mask is not None:
attn_mask = attn_mask.astype(value.dtype).unsqueeze(-1)
value *= attn_mask
value = value.reshape([bs, Len_v, self.num_heads, self.head_dim])
sampling_offsets = self.sampling_offsets(query).reshape(
[bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2])
attention_weights = self.attention_weights(query).reshape(
[bs, Len_q, self.num_heads, self.num_levels * self.num_points])
attention_weights = F.softmax(attention_weights).reshape(
[bs, Len_q, self.num_heads, self.num_levels, self.num_points])
if reference_points.shape[-1] == 2:
offset_normalizer = value_spatial_shapes.flip([1]).reshape(
[1, 1, 1, self.num_levels, 1, 2])
sampling_locations = reference_points.reshape([
bs, Len_q, 1, self.num_levels, 1, 2
]) + sampling_offsets / offset_normalizer
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2] + sampling_offsets /
self.num_points * reference_points[:, :, None, :, None, 2:] *
0.5)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".
format(reference_points.shape[-1]))
output = self.ms_deformable_attn_core(
value, value_spatial_shapes, value_level_start_index,
sampling_locations, attention_weights)
output = self.output_proj(output)
return output
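# Illustrative sketch (not part of the diff): how 2-d reference points and
# learned offsets are combined in the forward above. Offsets are divided by
# each level's (W, H) so the sampling locations stay in normalized [0, 1]
# coordinates. All shapes below are toy values.
def _demo_sampling_locations():
    spatial_shapes = paddle.to_tensor([[32, 48], [16, 24]])  # (H, W) per level
    normalizer = spatial_shapes.flip([1]).reshape([1, 1, 1, 2, 1, 2]).astype('float32')
    ref = paddle.rand([1, 10, 1, 2, 1, 2])                   # (bs, query, 1, levels, 1, 2)
    offsets = paddle.randn([1, 10, 8, 2, 4, 2])              # heads=8, points=4
    return ref + offsets / normalizer                        # (1, 10, 8, 2, 4, 2)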
@register
class MultiScaleDeformablePoseAttention(nn.Layer):
"""An attention module used in PETR. `End-to-End Multi-Person
Pose Estimation with Transformers`.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 8.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 17.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_residual`.
Default: 0.1.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=17,
im2col_step=64,
dropout=0.1,
norm_cfg=None,
init_cfg=None,
batch_first=False,
lr_mult=0.1):
super().__init__()
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
self.batch_first = batch_first
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn("You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims,
num_heads * num_levels * num_points * 2,
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult))
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
try:
# use cuda op
from deformable_detr_ops import ms_deformable_attn
        except ImportError:
# use paddle func
from .utils import deformable_attention_core_func as ms_deformable_attn
self.ms_deformable_attn_core = ms_deformable_attn
self.init_weights()
def init_weights(self):
"""Default initialization for Parameters of Module."""
constant_(self.sampling_offsets.weight)
constant_(self.sampling_offsets.bias)
constant_(self.attention_weights.weight)
constant_(self.attention_weights.bias)
xavier_uniform_(self.value_proj.weight)
constant_(self.value_proj.bias)
xavier_uniform_(self.output_proj.weight)
constant_(self.output_proj.bias)
def forward(self,
query,
key,
value,
residual=None,
attn_mask=None,
reference_points=None,
value_spatial_shapes=None,
value_level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape (num_key, bs, embed_dims).
value (Tensor): The value tensor with shape
(num_key, bs, embed_dims).
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
reference_points (Tensor): The normalized reference points with
                shape (bs, num_query, num_levels, K*2); all elements are in the
                range [0, 1], top-left (0, 0), bottom-right (1, 1), including
padding area.
attn_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
value_spatial_shapes (Tensor): Spatial shape of features in
different level. With shape (num_levels, 2),
last dimension represent (h, w).
value_level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
bs, num_query, _ = query.shape
bs, num_key, _ = value.shape
assert (value_spatial_shapes[:, 0].numpy() *
value_spatial_shapes[:, 1].numpy()).sum() == num_key
value = self.value_proj(value)
if attn_mask is not None:
# value = value.masked_fill(attn_mask[..., None], 0.0)
value *= attn_mask.unsqueeze(-1)
value = value.reshape([bs, num_key, self.num_heads, -1])
sampling_offsets = self.sampling_offsets(query).reshape([
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
])
attention_weights = self.attention_weights(query).reshape(
[bs, num_query, self.num_heads, self.num_levels * self.num_points])
attention_weights = F.softmax(attention_weights, axis=-1)
attention_weights = attention_weights.reshape(
[bs, num_query, self.num_heads, self.num_levels, self.num_points])
if reference_points.shape[-1] == self.num_points * 2:
reference_points_reshape = reference_points.reshape(
(bs, num_query, self.num_levels, -1, 2)).unsqueeze(2)
x1 = reference_points[:, :, :, 0::2].min(axis=-1, keepdim=True)
y1 = reference_points[:, :, :, 1::2].min(axis=-1, keepdim=True)
x2 = reference_points[:, :, :, 0::2].max(axis=-1, keepdim=True)
y2 = reference_points[:, :, :, 1::2].max(axis=-1, keepdim=True)
w = paddle.clip(x2 - x1, min=1e-4)
h = paddle.clip(y2 - y1, min=1e-4)
wh = paddle.concat([w, h], axis=-1)[:, :, None, :, None, :]
sampling_locations = reference_points_reshape \
+ sampling_offsets * wh * 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2K, but get {reference_points.shape[-1]} instead.')
output = self.ms_deformable_attn_core(
value, value_spatial_shapes, value_level_start_index,
sampling_locations, attention_weights)
output = self.output_proj(output)
return output
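# Illustrative sketch (not part of the diff): the pose-box scaling used in the
# forward above. The K reference keypoints of each query are reduced to a
# bounding box, and sampling offsets are scaled by half of its (w, h). Toy
# shapes only.
def _demo_pose_box_wh():
    ref = paddle.rand([2, 5, 3, 17 * 2])               # (bs, num_query, num_levels, K*2)
    x1 = ref[:, :, :, 0::2].min(axis=-1, keepdim=True)
    y1 = ref[:, :, :, 1::2].min(axis=-1, keepdim=True)
    x2 = ref[:, :, :, 0::2].max(axis=-1, keepdim=True)
    y2 = ref[:, :, :, 1::2].max(axis=-1, keepdim=True)
    w = paddle.clip(x2 - x1, min=1e-4)
    h = paddle.clip(y2 - y1, min=1e-4)
    return paddle.concat([w, h], axis=-1)               # (bs, num_query, num_levels, 2)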
@register
class PETR_TransformerDecoderLayer(nn.Layer):
__inject__ = ['self_attn', 'cross_attn']
def __init__(self,
d_model,
nhead=8,
self_attn=None,
cross_attn=None,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
attn_dropout=None,
act_dropout=None,
normalize_before=False):
super(PETR_TransformerDecoderLayer, self).__init__()
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
if self_attn is None:
self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
else:
self.self_attn = self_attn
if cross_attn is None:
self.cross_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
else:
self.cross_attn = cross_attn
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train")
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout3 = nn.Dropout(dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
self._reset_parameters()
def _reset_parameters(self):
linear_init_(self.linear1)
linear_init_(self.linear2)
@staticmethod
def with_pos_embed(tensor, pos_embed):
return tensor if pos_embed is None else tensor + pos_embed
def forward(self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
pos_embed=None,
query_pos_embed=None,
**kwargs):
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
q = k = self.with_pos_embed(tgt, query_pos_embed)
tgt = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask)
tgt = residual + self.dropout1(tgt)
if not self.normalize_before:
tgt = self.norm1(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
q = self.with_pos_embed(tgt, query_pos_embed)
key_tmp = tgt
# k = self.with_pos_embed(memory, pos_embed)
tgt = self.cross_attn(
q, key=key_tmp, value=memory, attn_mask=memory_mask, **kwargs)
tgt = residual + self.dropout2(tgt)
if not self.normalize_before:
tgt = self.norm2(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm3(tgt)
tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = residual + self.dropout3(tgt)
if not self.normalize_before:
tgt = self.norm3(tgt)
return tgt
@register
class PETR_TransformerDecoder(nn.Layer):
"""Implements the decoder in PETR transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
"""
__inject__ = ['decoder_layer']
def __init__(self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False,
num_keypoints=17,
**kwargs):
super(PETR_TransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
self.num_keypoints = num_keypoints
def forward(self,
query,
*args,
reference_points=None,
valid_ratios=None,
kpt_branches=None,
**kwargs):
"""Forward function for `TransformerDecoder`.
Args:
query (Tensor): Input query with shape (num_query, bs, embed_dims).
reference_points (Tensor): The reference points of offset,
has shape (bs, num_query, K*2).
            valid_ratios (Tensor): The ratios of valid points on the feature
map, has shape (bs, num_levels, 2).
kpt_branches: (obj:`nn.LayerList`): Used for refining the
regression results. Only would be passed when `with_box_refine`
is True, otherwise would be passed a `None`.
Returns:
tuple (Tensor): Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims] and
[num_layers, bs, num_query, K*2].
"""
output = query
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == self.num_keypoints * 2:
reference_points_input = \
reference_points[:, :, None] * \
valid_ratios.tile((1, 1, self.num_keypoints))[:, None]
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * \
valid_ratios[:, None]
output = layer(
output,
*args,
reference_points=reference_points_input,
**kwargs)
if kpt_branches is not None:
tmp = kpt_branches[lid](output)
if reference_points.shape[-1] == self.num_keypoints * 2:
new_reference_points = tmp + inverse_sigmoid(
reference_points)
new_reference_points = F.sigmoid(new_reference_points)
else:
raise NotImplementedError
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return paddle.stack(intermediate), paddle.stack(
intermediate_reference_points)
return output, reference_points
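# Illustrative sketch (not part of the diff): one reference-point refinement
# step from the loop above. The keypoint branch predicts a delta in logit
# space, which is added to the inverse-sigmoid of the current normalized
# reference point, so each coordinate stays inside (0, 1). Toy numbers only.
def _demo_refine_step():
    ref = paddle.to_tensor([[0.40, 0.60]])       # current normalized keypoint
    delta = paddle.to_tensor([[0.20, -0.10]])    # predicted logit-space offset
    return F.sigmoid(delta + inverse_sigmoid(ref))   # ~[[0.449, 0.576]]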
@register
class PETR_DeformableTransformerDecoder(nn.Layer):
__inject__ = ['decoder_layer']
def __init__(self, decoder_layer, num_layers, return_intermediate=False):
super(PETR_DeformableTransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.return_intermediate = return_intermediate
def forward(self,
tgt,
reference_points,
memory,
memory_spatial_shapes,
memory_mask=None,
query_pos_embed=None):
output = tgt
intermediate = []
for lid, layer in enumerate(self.layers):
output = layer(output, reference_points, memory,
memory_spatial_shapes, memory_mask, query_pos_embed)
if self.return_intermediate:
intermediate.append(output)
if self.return_intermediate:
return paddle.stack(intermediate)
return output.unsqueeze(0)
@register
class PETR_DeformableDetrTransformerDecoder(PETR_DeformableTransformerDecoder):
"""Implements the decoder in DETR transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
"""
def __init__(self, *args, return_intermediate=False, **kwargs):
super(PETR_DeformableDetrTransformerDecoder, self).__init__(*args,
**kwargs)
self.return_intermediate = return_intermediate
def forward(self,
query,
*args,
reference_points=None,
valid_ratios=None,
reg_branches=None,
**kwargs):
"""Forward function for `TransformerDecoder`.
Args:
query (Tensor): Input query with shape
`(num_query, bs, embed_dims)`.
reference_points (Tensor): The reference
                points of offset, has shape
                (bs, num_query, 4) when as_two_stage,
                otherwise has shape (bs, num_query, 2).
            valid_ratios (Tensor): The ratios of valid
points on the feature map, has shape
(bs, num_levels, 2)
            reg_branches (obj:`nn.LayerList`): Used for
refining the regression results. Only would
be passed when with_box_refine is True,
otherwise would be passed a `None`.
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
output = query
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = reference_points[:, :, None] * \
paddle.concat([valid_ratios, valid_ratios], -1)[:, None]
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * \
valid_ratios[:, None]
output = layer(
output,
*args,
reference_points=reference_points_input,
**kwargs)
if reg_branches is not None:
tmp = reg_branches[lid](output)
if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(
reference_points)
new_reference_points = F.sigmoid(new_reference_points)
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[
..., :2] + inverse_sigmoid(reference_points)
new_reference_points = F.sigmoid(new_reference_points)
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return paddle.stack(intermediate), paddle.stack(
intermediate_reference_points)
return output, reference_points
@register
class PETRTransformer(nn.Layer):
"""Implements the PETR transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN:
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
__inject__ = ["encoder", "decoder", "hm_encoder", "refine_decoder"]
def __init__(self,
encoder="",
decoder="",
hm_encoder="",
refine_decoder="",
as_two_stage=True,
num_feature_levels=4,
two_stage_num_proposals=300,
num_keypoints=17,
**kwargs):
super(PETRTransformer, self).__init__(**kwargs)
self.as_two_stage = as_two_stage
self.num_feature_levels = num_feature_levels
self.two_stage_num_proposals = two_stage_num_proposals
self.num_keypoints = num_keypoints
self.encoder = encoder
self.decoder = decoder
self.embed_dims = self.encoder.embed_dims
self.hm_encoder = hm_encoder
self.refine_decoder = refine_decoder
self.init_layers()
self.init_weights()
def init_layers(self):
"""Initialize layers of the DeformableDetrTransformer."""
# learnable per-level embedding; paddle.create_parameter plays the role of torch's nn.Parameter
self.level_embeds = paddle.create_parameter(
(self.num_feature_levels, self.embed_dims), dtype="float32")
if self.as_two_stage:
self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
self.enc_output_norm = nn.LayerNorm(self.embed_dims)
self.refine_query_embedding = nn.Embedding(self.num_keypoints,
self.embed_dims * 2)
else:
self.reference_points = nn.Linear(self.embed_dims,
2 * self.num_keypoints)
def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.rank() > 1:
xavier_uniform_(p)
if hasattr(p, 'bias') and p.bias is not None:
constant_(p.bias)
for m in self.sublayers():
if isinstance(m, MSDeformableAttention):
m._reset_parameters()
for m in self.sublayers():
if isinstance(m, MultiScaleDeformablePoseAttention):
m.init_weights()
if not self.as_two_stage:
xavier_uniform_(self.reference_points.weight)
constant_(self.reference_points.bias)
normal_(self.level_embeds)
normal_(self.refine_query_embedding.weight)
def gen_encoder_output_proposals(self, memory, memory_padding_mask,
spatial_shapes):
"""Generate proposals from encoded memory.
Args:
memory (Tensor): The output of encoder, has shape
(bs, num_key, embed_dim). num_key equals the number of points
on the feature maps from all levels.
memory_padding_mask (Tensor): Padding mask for memory.
has shape (bs, num_key).
spatial_shapes (Tensor): The shape of all feature maps.
has shape (num_level, 2).
Returns:
tuple: A tuple of feature map and bbox prediction.
- output_memory (Tensor): The input of decoder, has shape
(bs, num_key, embed_dim). num_key equals the number of
points on the feature maps from all levels.
- output_proposals (Tensor): The normalized proposals
after an inverse sigmoid, has shape (bs, num_keys, 2).
"""
N, S, C = memory.shape
proposals = []
_cur = 0
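# For each level, build a grid of cell centers over the valid (non-padded)
# region and normalize it to [0, 1]; these grids serve as the initial proposals.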
for lvl, (H, W) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].reshape(
[N, H, W, 1])
valid_H = paddle.sum(mask_flatten_[:, :, 0, 0], 1)
valid_W = paddle.sum(mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = paddle.meshgrid(
paddle.linspace(
0, H - 1, H, dtype="float32"),
paddle.linspace(
0, W - 1, W, dtype="float32"))
grid = paddle.concat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)],
-1)
scale = paddle.concat(
[valid_W.unsqueeze(-1),
valid_H.unsqueeze(-1)], 1).reshape([N, 1, 1, 2])
grid = (grid.unsqueeze(0).expand((N, -1, -1, -1)) + 0.5) / scale
proposal = grid.reshape([N, -1, 2])
proposals.append(proposal)
_cur += (H * W)
output_proposals = paddle.concat(proposals, 1)
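# Keep only proposals that fall safely inside the image (0.01 < p < 0.99),
# map them to inverse-sigmoid (logit) space, and mask out padded or invalid
# positions so the decoder ignores them.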
output_proposals_valid = ((output_proposals > 0.01) &
(output_proposals < 0.99)).all(
-1, keepdim=True).astype("bool")
output_proposals = paddle.log(output_proposals / (1 - output_proposals))
output_proposals = masked_fill(
output_proposals, ~memory_padding_mask.astype("bool").unsqueeze(-1),
float('inf'))
output_proposals = masked_fill(output_proposals,
~output_proposals_valid, float('inf'))
output_memory = memory
output_memory = masked_fill(
output_memory, ~memory_padding_mask.astype("bool").unsqueeze(-1),
float(0))
output_memory = masked_fill(output_memory, ~output_proposals_valid,
float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios):
"""Get the reference points used in decoder.
Args:
spatial_shapes (Tensor): The shape of all feature maps,
has shape (num_level, 2).
valid_ratios (Tensor): The ratios of valid points on the
feature map, has shape (bs, num_levels, 2).
Returns:
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2).
"""
reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
ref_y, ref_x = paddle.meshgrid(
paddle.linspace(
0.5, H - 0.5, H, dtype="float32"),
paddle.linspace(
0.5, W - 0.5, W, dtype="float32"))
ref_y = ref_y.reshape(
(-1, ))[None] / (valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(
(-1, ))[None] / (valid_ratios[:, None, lvl, 0] * W)
ref = paddle.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = paddle.concat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def get_valid_ratio(self, mask):
"""Get the valid radios of feature maps of all level."""
_, H, W = mask.shape
valid_H = paddle.sum(mask[:, :, 0].astype('float'), 1)
valid_W = paddle.sum(mask[:, 0, :].astype('float'), 1)
valid_ratio_h = valid_H.astype('float') / H
valid_ratio_w = valid_W.astype('float') / W
valid_ratio = paddle.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def get_proposal_pos_embed(self,
proposals,
num_pos_feats=128,
temperature=10000):
"""Get the position embedding of proposal."""
scale = 2 * math.pi
dim_t = paddle.arange(num_pos_feats, dtype="float32")
dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
# N, L, 4
proposals = F.sigmoid(proposals) * scale
# N, L, 4, 128
pos = proposals[:, :, :, None] / dim_t
# N, L, 4, 64, 2
pos = paddle.stack(
(pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
axis=4).flatten(2)
return pos
def forward(self,
mlvl_feats,
mlvl_masks,
query_embed,
mlvl_pos_embeds,
kpt_branches=None,
cls_branches=None):
"""Forward function for `Transformer`.
Args:
mlvl_feats (list(Tensor)): Input feature maps from different levels.
Each element has shape [bs, embed_dims, h, w].
mlvl_masks (list(Tensor)): The key_padding_mask from different
levels used for encoder and decoder, each element has shape
[bs, h, w].
query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
mlvl_pos_embeds (list(Tensor)): The positional encoding
of feats from different level, has the shape
[bs, embed_dims, h, w].
kpt_branches (obj:`nn.LayerList`): Keypoint regression heads for
feature maps from each decoder layer. Only passed when
`with_box_refine` is True. Default to None.
cls_branches (obj:`nn.LayerList`): Classification heads for
feature maps from each decoder layer. Only passed when
`as_two_stage` is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- inter_states: Outputs from decoder. If
`return_intermediate_dec` is True output has shape \
(num_dec_layers, bs, num_query, embed_dims), else has \
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference \
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference \
points in decoder, has shape \
(num_dec_layers, bs, num_query, embed_dims)
- enc_outputs_class: The classification score of proposals \
generated from encoder's feature maps, has shape \
(batch, h*w, num_classes). \
Only would be returned when `as_two_stage` is True, \
otherwise None.
- enc_outputs_kpt_unact: The regression results generated from \
encoder's feature maps, has shape (batch, h*w, K*2). \
Only would be returned when `as_two_stage` is True, \
otherwise None.
"""
assert self.as_two_stage or query_embed is not None
feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
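# Flatten features, masks and positional embeddings of all levels into one
# token sequence; level_embeds distinguishes tokens from different levels.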
for lvl, (feat, mask, pos_embed
) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
bs, c, h, w = feat.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
feat = feat.flatten(2).transpose((0, 2, 1))
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose((0, 2, 1))
lvl_pos_embed = pos_embed + self.level_embeds[lvl].reshape(
[1, 1, -1])
lvl_pos_embed_flatten.append(lvl_pos_embed)
feat_flatten.append(feat)
mask_flatten.append(mask)
feat_flatten = paddle.concat(feat_flatten, 1)
mask_flatten = paddle.concat(mask_flatten, 1)
lvl_pos_embed_flatten = paddle.concat(lvl_pos_embed_flatten, 1)
spatial_shapes_cumsum = paddle.to_tensor(
np.array(spatial_shapes).prod(1).cumsum(0))
spatial_shapes = paddle.to_tensor(spatial_shapes, dtype="int64")
level_start_index = paddle.concat((paddle.zeros(
(1, ), dtype=spatial_shapes.dtype), spatial_shapes_cumsum[:-1]))
valid_ratios = paddle.stack(
[self.get_valid_ratio(m) for m in mlvl_masks], 1)
reference_points = \
self.get_reference_points(spatial_shapes,
valid_ratios)
memory = self.encoder(
src=feat_flatten,
pos_embed=lvl_pos_embed_flatten,
src_mask=mask_flatten,
value_spatial_shapes=spatial_shapes,
reference_points=reference_points,
value_level_start_index=level_start_index,
valid_ratios=valid_ratios)
bs, _, c = memory.shape
hm_proto = None
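# During training, tokens of the first (highest-resolution) level are passed
# through an extra heatmap encoder; the reshaped result is returned as
# hm_proto for auxiliary heatmap supervision in the head.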
if self.training:
hm_memory = paddle.slice(
memory,
starts=level_start_index[0],
ends=level_start_index[1],
axes=[1])
hm_pos_embed = paddle.slice(
lvl_pos_embed_flatten,
starts=level_start_index[0],
ends=level_start_index[1],
axes=[1])
hm_mask = paddle.slice(
mask_flatten,
starts=level_start_index[0],
ends=level_start_index[1],
axes=[1])
hm_reference_points = paddle.slice(
reference_points,
starts=level_start_index[0],
ends=level_start_index[1],
axes=[1])[:, :, :1, :]
# note: the official code passes pos_embed as pose_embed by mistake, which effectively disables the positional embedding; kept here for consistency
hm_memory = self.hm_encoder(
src=hm_memory,
pose_embed=hm_pos_embed,
src_mask=hm_mask,
value_spatial_shapes=spatial_shapes[[0]],
reference_points=hm_reference_points,
value_level_start_index=level_start_index[0],
valid_ratios=valid_ratios[:, :1, :])
hm_memory = hm_memory.reshape((bs, spatial_shapes[0, 0],
spatial_shapes[0, 1], -1))
hm_proto = (hm_memory, mlvl_masks[0])
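# Two-stage: score every encoder token with the class branch, keep the top-k
# proposals, and use their detached keypoint regressions as the initial
# reference points for the decoder.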
if self.as_two_stage:
output_memory, output_proposals = \
self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes)
enc_outputs_class = cls_branches[self.decoder.num_layers](
output_memory)
enc_outputs_kpt_unact = \
kpt_branches[self.decoder.num_layers](output_memory)
enc_outputs_kpt_unact[..., 0::2] += output_proposals[..., 0:1]
enc_outputs_kpt_unact[..., 1::2] += output_proposals[..., 1:2]
topk = self.two_stage_num_proposals
topk_proposals = paddle.topk(
enc_outputs_class[..., 0], topk, axis=1)[1].unsqueeze(-1)
# paddle.take_along_axis corresponds to torch.gather
topk_kpts_unact = paddle.take_along_axis(enc_outputs_kpt_unact,
topk_proposals, 1)
topk_kpts_unact = topk_kpts_unact.detach()
reference_points = F.sigmoid(topk_kpts_unact)
init_reference_out = reference_points
# learnable query and query_pos
query_pos, query = paddle.split(
query_embed, query_embed.shape[1] // c, axis=1)
query_pos = query_pos.unsqueeze(0).expand((bs, -1, -1))
query = query.unsqueeze(0).expand((bs, -1, -1))
else:
query_pos, query = paddle.split(
query_embed, query_embed.shape[1] // c, axis=1)
query_pos = query_pos.unsqueeze(0).expand((bs, -1, -1))
query = query.unsqueeze(0).expand((bs, -1, -1))
reference_points = F.sigmoid(self.reference_points(query_pos))
init_reference_out = reference_points
# decoder
inter_states, inter_references = self.decoder(
query=query,
memory=memory,
query_pos_embed=query_pos,
memory_mask=mask_flatten,
reference_points=reference_points,
value_spatial_shapes=spatial_shapes,
value_level_start_index=level_start_index,
valid_ratios=valid_ratios,
kpt_branches=kpt_branches)
inter_references_out = inter_references
if self.as_two_stage:
return inter_states, init_reference_out, \
inter_references_out, enc_outputs_class, \
enc_outputs_kpt_unact, hm_proto, memory
return inter_states, init_reference_out, \
inter_references_out, None, None, None, None, None, hm_proto
def forward_refine(self,
mlvl_masks,
memory,
reference_points_pose,
img_inds,
kpt_branches=None,
**kwargs):
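"""Refine keypoints with the refine decoder: for each detected pose,
num_keypoints learnable queries attend to that image's encoder memory,
starting from the pose's predicted keypoints as reference points."""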
mask_flatten = []
spatial_shapes = []
for lvl, mask in enumerate(mlvl_masks):
bs, h, w = mask.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
mask = mask.flatten(1)
mask_flatten.append(mask)
mask_flatten = paddle.concat(mask_flatten, 1)
spatial_shapes_cumsum = paddle.to_tensor(
np.array(
spatial_shapes, dtype='int64').prod(1).cumsum(0))
spatial_shapes = paddle.to_tensor(spatial_shapes, dtype="int64")
level_start_index = paddle.concat((paddle.zeros(
(1, ), dtype=spatial_shapes.dtype), spatial_shapes_cumsum[:-1]))
valid_ratios = paddle.stack(
[self.get_valid_ratio(m) for m in mlvl_masks], 1)
# pose refinement (17 queries corresponding to 17 keypoints)
# learnable query and query_pos
refine_query_embedding = self.refine_query_embedding.weight
query_pos, query = paddle.split(refine_query_embedding, 2, axis=1)
pos_num = reference_points_pose.shape[0]
query_pos = query_pos.unsqueeze(0).expand((pos_num, -1, -1))
query = query.unsqueeze(0).expand((pos_num, -1, -1))
reference_points = reference_points_pose.reshape(
(pos_num, reference_points_pose.shape[1] // 2, 2))
pos_memory = memory[img_inds]
mask_flatten = mask_flatten[img_inds]
valid_ratios = valid_ratios[img_inds]
if img_inds.size == 1:
pos_memory = pos_memory.unsqueeze(0)
mask_flatten = mask_flatten.unsqueeze(0)
valid_ratios = valid_ratios.unsqueeze(0)
inter_states, inter_references = self.refine_decoder(
query=query,
memory=pos_memory,
query_pos_embed=query_pos,
memory_mask=mask_flatten,
reference_points=reference_points,
value_spatial_shapes=spatial_shapes,
value_level_start_index=level_start_index,
valid_ratios=valid_ratios,
reg_branches=kpt_branches,
**kwargs)
# [num_decoder, num_query, bs, embed_dim]
init_reference_out = reference_points
return inter_states, init_reference_out, inter_references
......@@ -238,7 +238,7 @@ def draw_pose(image,
'for example: `pip install matplotlib`.')
raise e
skeletons = np.array([item['keypoints'] for item in results])
skeletons = np.array([item['keypoints'] for item in results]).reshape((-1, 51))
kpt_nums = 17
if len(skeletons) > 0:
kpt_nums = int(skeletons.shape[1] / 3)
......