未验证 提交 bec57bcf 编写于 作者: Z zhiboniu 提交者: GitHub

[cherry-pick]cherry-pick of petr and tinypose3d_human3.6m (#7816)

* keypoint petr (#7774)

* petr train ok

train ok

refix augsize

affine size fix

update msdeformable

fix flip/affine

fix clip

add resize area

add distortion

debug mode

fix pos_inds

update edge joints

update word mistake

* delete extra codes;adapt transformer modify;update code format

* reverse old transformer modify

* integrate datasets

* add config and architecture for human36m (#7802)

* add config and architecture for human36m

* modify TinyPose3DHRNet to support human3.6M dataset

* delete useless class

---------
Co-authored-by: XYZ_916's avatarXYZ <1290573099@qq.com>
上级 93e2d433
...@@ -18,9 +18,9 @@ __pycache__/ ...@@ -18,9 +18,9 @@ __pycache__/
# Distribution / packaging # Distribution / packaging
/bin/ /bin/
/build/ *build/
/develop-eggs/ /develop-eggs/
/dist/ *dist/
/eggs/ /eggs/
/lib/ /lib/
/lib64/ /lib64/
...@@ -30,7 +30,7 @@ __pycache__/ ...@@ -30,7 +30,7 @@ __pycache__/
/parts/ /parts/
/sdist/ /sdist/
/var/ /var/
/*.egg-info/ *.egg-info/
/.installed.cfg /.installed.cfg
/*.egg /*.egg
/.eggs /.eggs
......
...@@ -56,8 +56,10 @@ PaddleDetection 中的关键点检测部分紧跟最先进的算法,包括 Top ...@@ -56,8 +56,10 @@ PaddleDetection 中的关键点检测部分紧跟最先进的算法,包括 Top
## 模型库 ## 模型库
COCO数据集 COCO数据集
| 模型 | 方案 |输入尺寸 | AP(coco val) | 模型下载 | 配置文件 | | 模型 | 方案 |输入尺寸 | AP(coco val) | 模型下载 | 配置文件 |
| :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------| ------- | | :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------| ------- |
| PETR_Res50 |One-Stage| 512 | 65.5 | [petr_res50.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/petr_resnet50_16x2_coco.pdparams) | [config](./petr/petr_resnet50_16x2_coco.yml) |
| HigherHRNet-w32 |Bottom-Up| 512 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) | | HigherHRNet-w32 |Bottom-Up| 512 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) |
| HigherHRNet-w32 | Bottom-Up| 640 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) | | HigherHRNet-w32 | Bottom-Up| 640 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) |
| HigherHRNet-w32+SWAHR |Bottom-Up| 512 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) | | HigherHRNet-w32+SWAHR |Bottom-Up| 512 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) |
......
...@@ -62,6 +62,7 @@ At the same time, PaddleDetection provides a self-developed real-time keypoint d ...@@ -62,6 +62,7 @@ At the same time, PaddleDetection provides a self-developed real-time keypoint d
COCO Dataset COCO Dataset
| Model | Input Size | AP(coco val) | Model Download | Config File | | Model | Input Size | AP(coco val) | Model Download | Config File |
| :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------------- | | :---------------- | -------- | :----------: | :----------------------------------------------------------: | ----------------------------------------------------------- |
| PETR_Res50 |One-Stage| 512 | 65.5 | [petr_res50.pdparams](https://bj.bcebos.com/v1/paddledet/models/keypoint/petr_resnet50_16x2_coco.pdparams) | [config](./petr/petr_resnet50_16x2_coco.yml) |
| HigherHRNet-w32 | 512 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) | | HigherHRNet-w32 | 512 | 67.1 | [higherhrnet_hrnet_w32_512.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512.yml) |
| HigherHRNet-w32 | 640 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) | | HigherHRNet-w32 | 640 | 68.3 | [higherhrnet_hrnet_w32_640.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_640.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_640.yml) |
| HigherHRNet-w32+SWAHR | 512 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) | | HigherHRNet-w32+SWAHR | 512 | 68.9 | [higherhrnet_hrnet_w32_512_swahr.pdparams](https://paddledet.bj.bcebos.com/models/keypoint/higherhrnet_hrnet_w32_512_swahr.pdparams) | [config](./higherhrnet/higherhrnet_hrnet_w32_512_swahr.yml) |
......
use_gpu: true
log_iter: 50
save_dir: output
snapshot_epoch: 1
weights: output/petr_resnet50_16x2_coco/model_final
epoch: 100
num_joints: &num_joints 17
pixel_std: &pixel_std 200
metric: COCO
num_classes: 1
trainsize: &trainsize 512
flip_perm: &flip_perm [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
find_unused_parameters: False
#####model
architecture: PETR
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/PETR_pretrained.pdparams
PETR:
backbone:
name: ResNet
depth: 50
variant: b
norm_type: bn
freeze_norm: True
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
lr_mult_list: [0.1, 0.1, 0.1, 0.1]
neck:
name: ChannelMapper
in_channels: [512, 1024, 2048]
kernel_size: 1
out_channels: 256
norm_type: "gn"
norm_groups: 32
act: None
num_outs: 4
bbox_head:
name: PETRHead
num_query: 300
num_classes: 1 # only person
in_channels: 2048
sync_cls_avg_factor: true
with_kpt_refine: true
transformer:
name: PETRTransformer
as_two_stage: true
encoder:
name: TransformerEncoder
encoder_layer:
name: TransformerEncoderLayer
d_model: 256
attn:
name: MSDeformableAttention
embed_dim: 256
num_heads: 8
num_levels: 4
num_points: 4
dim_feedforward: 1024
dropout: 0.1
num_layers: 6
decoder:
name: PETR_TransformerDecoder
num_layers: 3
return_intermediate: true
decoder_layer:
name: PETR_TransformerDecoderLayer
d_model: 256
dim_feedforward: 1024
dropout: 0.1
self_attn:
name: MultiHeadAttention
embed_dim: 256
num_heads: 8
dropout: 0.1
cross_attn:
name: MultiScaleDeformablePoseAttention
embed_dims: 256
num_heads: 8
num_levels: 4
num_points: 17
hm_encoder:
name: TransformerEncoder
encoder_layer:
name: TransformerEncoderLayer
d_model: 256
attn:
name: MSDeformableAttention
embed_dim: 256
num_heads: 8
num_levels: 1
num_points: 4
dim_feedforward: 1024
dropout: 0.1
num_layers: 1
refine_decoder:
name: PETR_DeformableDetrTransformerDecoder
num_layers: 2
return_intermediate: true
decoder_layer:
name: PETR_TransformerDecoderLayer
d_model: 256
dim_feedforward: 1024
dropout: 0.1
self_attn:
name: MultiHeadAttention
embed_dim: 256
num_heads: 8
dropout: 0.1
cross_attn:
name: MSDeformableAttention
embed_dim: 256
num_levels: 4
positional_encoding:
name: PositionEmbedding
num_pos_feats: 128
normalize: true
offset: -0.5
loss_cls:
name: Weighted_FocalLoss
use_sigmoid: true
gamma: 2.0
alpha: 0.25
loss_weight: 2.0
reduction: "mean"
loss_kpt:
name: L1Loss
loss_weight: 70.0
loss_kpt_rpn:
name: L1Loss
loss_weight: 70.0
loss_oks:
name: OKSLoss
loss_weight: 2.0
loss_hm:
name: CenterFocalLoss
loss_weight: 4.0
loss_kpt_refine:
name: L1Loss
loss_weight: 80.0
loss_oks_refine:
name: OKSLoss
loss_weight: 3.0
assigner:
name: PoseHungarianAssigner
cls_cost:
name: FocalLossCost
weight: 2.0
kpt_cost:
name: KptL1Cost
weight: 70.0
oks_cost:
name: OksCost
weight: 7.0
#####optimizer
LearningRate:
base_lr: 0.0002
schedulers:
- !PiecewiseDecay
milestones: [80]
gamma: 0.1
use_warmup: false
# - !LinearWarmup
# start_factor: 0.001
# steps: 1000
OptimizerBuilder:
clip_grad_by_norm: 0.1
optimizer:
type: AdamW
regularizer:
factor: 0.0001
type: L2
#####data
TrainDataset:
!KeypointBottomUpCocoDataset
image_dir: train2017
anno_path: annotations/person_keypoints_train2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
return_mask: false
EvalDataset:
!KeypointBottomUpCocoDataset
image_dir: val2017
anno_path: annotations/person_keypoints_val2017.json
dataset_dir: dataset/coco
num_joints: *num_joints
test_mode: true
return_mask: false
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 2
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- Decode: {}
- PhotoMetricDistortion:
brightness_delta: 32
contrast_range: [0.5, 1.5]
saturation_range: [0.5, 1.5]
hue_delta: 18
- KeyPointFlip:
flip_prob: 0.5
flip_permutation: *flip_perm
- RandomAffine:
max_degree: 30
scale: [1.0, 1.0]
max_shift: 0.
trainsize: -1
- RandomSelect: { transforms1: [ RandomShortSideRangeResize: { scales: [[400, 1400], [1400, 1400]]} ],
transforms2: [
RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] },
RandomSizeCrop: { min_size: 384, max_size: 600},
RandomShortSideRangeResize: { scales: [[400, 1400], [1400, 1400]]} ]}
batch_transforms:
- NormalizeImage: {mean: *global_mean, std: *global_std, is_scale: True}
- PadGT: {pad_img: True, minimum_gtnum: 1}
- Permute: {}
batch_size: 2
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true
EvalReader:
sample_transforms:
- PETR_Resize: {img_scale: [[800, 1333]], keep_ratio: True}
# - MultiscaleTestResize: {origin_target_size: [[800, 1333]], use_flip: false}
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
TestReader:
sample_transforms:
- Decode: {}
- EvalAffine:
size: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
use_gpu: true
log_iter: 5
save_dir: output
snapshot_epoch: 1
weights: output/tinypose3d_human36M/model_final
epoch: 220
num_joints: &num_joints 24
pixel_std: &pixel_std 200
metric: Pose3DEval
num_classes: 1
train_height: &train_height 128
train_width: &train_width 128
trainsize: &trainsize [*train_width, *train_height]
#####model
architecture: TinyPose3DHRNet
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.pdparams
TinyPose3DHRNet:
backbone: LiteHRNet
post_process: HR3DNetPostProcess
fc_channel: 1024
num_joints: *num_joints
width: &width 40
loss: Pose3DLoss
LiteHRNet:
network_type: wider_naive
freeze_at: -1
freeze_norm: false
return_idx: [0]
Pose3DLoss:
weight_3d: 1.0
weight_2d: 0.0
#####optimizer
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
milestones: [17, 21]
gamma: 0.1
- !LinearWarmup
start_factor: 0.01
steps: 1000
OptimizerBuilder:
optimizer:
type: Adam
regularizer:
factor: 0.0
type: L2
#####data
TrainDataset:
!Pose3DDataset
dataset_dir: Human3.6M
image_dirs: ["Images"]
anno_list: ['Human3.6m_train.json']
num_joints: *num_joints
test_mode: False
EvalDataset:
!Pose3DDataset
dataset_dir: Human3.6M
image_dirs: ["Images"]
anno_list: ['Human3.6m_valid.json']
num_joints: *num_joints
test_mode: True
TestDataset:
!ImageFolder
anno_path: dataset/coco/keypoint_imagelist.txt
worker_num: 4
global_mean: &global_mean [0.485, 0.456, 0.406]
global_std: &global_std [0.229, 0.224, 0.225]
TrainReader:
sample_transforms:
- SinglePoseAffine:
trainsize: *trainsize
rotate: [0.5, 30] #[prob, rotate range]
scale: [0.5, 0.25] #[prob, scale range]
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 128
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- SinglePoseAffine:
trainsize: *trainsize
rotate: [0., 30]
scale: [0., 0.25]
batch_transforms:
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 128
TestReader:
inputs_def:
image_shape: [3, *train_height, *train_width]
sample_transforms:
- Decode: {}
- TopDownEvalAffine:
trainsize: *trainsize
- NormalizeImage:
mean: *global_mean
std: *global_std
is_scale: true
- Permute: {}
batch_size: 1
fuse_normalize: false
...@@ -82,7 +82,7 @@ MPII keypoint indexes: ...@@ -82,7 +82,7 @@ MPII keypoint indexes:
``` ```
{ {
'joints_vis': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'joints_vis': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
'joints': [ 'gt_joints': [
[-1.0, -1.0], [-1.0, -1.0],
[-1.0, -1.0], [-1.0, -1.0],
[-1.0, -1.0], [-1.0, -1.0],
......
...@@ -82,7 +82,7 @@ The following example takes a parsed annotation information to illustrate the co ...@@ -82,7 +82,7 @@ The following example takes a parsed annotation information to illustrate the co
``` ```
{ {
'joints_vis': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'joints_vis': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
'joints': [ 'gt_joints': [
[-1.0, -1.0], [-1.0, -1.0],
[-1.0, -1.0], [-1.0, -1.0],
[-1.0, -1.0], [-1.0, -1.0],
......
...@@ -80,7 +80,8 @@ class KeypointBottomUpBaseDataset(DetDataset): ...@@ -80,7 +80,8 @@ class KeypointBottomUpBaseDataset(DetDataset):
records = copy.deepcopy(self._get_imganno(idx)) records = copy.deepcopy(self._get_imganno(idx))
records['image'] = cv2.imread(records['image_file']) records['image'] = cv2.imread(records['image_file'])
records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB) records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
records['mask'] = (records['mask'] + 0).astype('uint8') if 'mask' in records:
records['mask'] = (records['mask'] + 0).astype('uint8')
records = self.transform(records) records = self.transform(records)
return records return records
...@@ -135,24 +136,37 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): ...@@ -135,24 +136,37 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
num_joints, num_joints,
transform=[], transform=[],
shard=[0, 1], shard=[0, 1],
test_mode=False): test_mode=False,
return_mask=True,
return_bbox=True,
return_area=True,
return_class=True):
super().__init__(dataset_dir, image_dir, anno_path, num_joints, super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform, shard, test_mode) transform, shard, test_mode)
self.ann_file = os.path.join(dataset_dir, anno_path) self.ann_file = os.path.join(dataset_dir, anno_path)
self.shard = shard self.shard = shard
self.test_mode = test_mode self.test_mode = test_mode
self.return_mask = return_mask
self.return_bbox = return_bbox
self.return_area = return_area
self.return_class = return_class
def parse_dataset(self): def parse_dataset(self):
self.coco = COCO(self.ann_file) self.coco = COCO(self.ann_file)
self.img_ids = self.coco.getImgIds() self.img_ids = self.coco.getImgIds()
if not self.test_mode: if not self.test_mode:
self.img_ids = [ self.img_ids_tmp = []
img_id for img_id in self.img_ids for img_id in self.img_ids:
if len(self.coco.getAnnIds( ann_ids = self.coco.getAnnIds(imgIds=img_id)
imgIds=img_id, iscrowd=None)) > 0 anno = self.coco.loadAnns(ann_ids)
] anno = [obj for obj in anno if obj['iscrowd'] == 0]
if len(anno) == 0:
continue
self.img_ids_tmp.append(img_id)
self.img_ids = self.img_ids_tmp
blocknum = int(len(self.img_ids) / self.shard[1]) blocknum = int(len(self.img_ids) / self.shard[1])
self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * ( self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
self.shard[0] + 1))] self.shard[0] + 1))]
...@@ -199,21 +213,31 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): ...@@ -199,21 +213,31 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
ann_ids = coco.getAnnIds(imgIds=img_id) ann_ids = coco.getAnnIds(imgIds=img_id)
anno = coco.loadAnns(ann_ids) anno = coco.loadAnns(ann_ids)
mask = self._get_mask(anno, idx)
anno = [ anno = [
obj for obj in anno obj for obj in anno
if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0 if obj['iscrowd'] == 0 and obj['num_keypoints'] > 0
] ]
db_rec = {}
joints, orgsize = self._get_joints(anno, idx) joints, orgsize = self._get_joints(anno, idx)
db_rec['gt_joints'] = joints
db_rec['im_shape'] = orgsize
if self.return_bbox:
db_rec['gt_bbox'] = self._get_bboxs(anno, idx)
if self.return_class:
db_rec['gt_class'] = self._get_labels(anno, idx)
if self.return_area:
db_rec['gt_areas'] = self._get_areas(anno, idx)
if self.return_mask:
db_rec['mask'] = self._get_mask(anno, idx)
db_rec = {}
db_rec['im_id'] = img_id db_rec['im_id'] = img_id
db_rec['image_file'] = os.path.join(self.img_prefix, db_rec['image_file'] = os.path.join(self.img_prefix,
self.id2name[img_id]) self.id2name[img_id])
db_rec['mask'] = mask
db_rec['joints'] = joints
db_rec['im_shape'] = orgsize
return db_rec return db_rec
...@@ -229,12 +253,41 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): ...@@ -229,12 +253,41 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
np.array(obj['keypoints']).reshape([-1, 3]) np.array(obj['keypoints']).reshape([-1, 3])
img_info = self.coco.loadImgs(self.img_ids[idx])[0] img_info = self.coco.loadImgs(self.img_ids[idx])[0]
joints[..., 0] /= img_info['width'] orgsize = np.array([img_info['height'], img_info['width'], 1])
joints[..., 1] /= img_info['height']
orgsize = np.array([img_info['height'], img_info['width']])
return joints, orgsize return joints, orgsize
def _get_bboxs(self, anno, idx):
num_people = len(anno)
gt_bboxes = np.zeros((num_people, 4), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'bbox' in obj:
gt_bboxes[idx, :] = obj['bbox']
gt_bboxes[:, 2] += gt_bboxes[:, 0]
gt_bboxes[:, 3] += gt_bboxes[:, 1]
return gt_bboxes
def _get_labels(self, anno, idx):
num_people = len(anno)
gt_labels = np.zeros((num_people, 1), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'category_id' in obj:
catid = obj['category_id']
gt_labels[idx, 0] = self.catid2clsid[catid]
return gt_labels
def _get_areas(self, anno, idx):
num_people = len(anno)
gt_areas = np.zeros((num_people, ), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'area' in obj:
gt_areas[idx, ] = obj['area']
return gt_areas
def _get_mask(self, anno, idx): def _get_mask(self, anno, idx):
"""Get ignore masks to mask out losses.""" """Get ignore masks to mask out losses."""
coco = self.coco coco = self.coco
...@@ -506,7 +559,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): ...@@ -506,7 +559,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
'image_file': os.path.join(self.img_prefix, file_name), 'image_file': os.path.join(self.img_prefix, file_name),
'center': center, 'center': center,
'scale': scale, 'scale': scale,
'joints': joints, 'gt_joints': joints,
'joints_vis': joints_vis, 'joints_vis': joints_vis,
'im_id': im_id, 'im_id': im_id,
}) })
...@@ -570,7 +623,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): ...@@ -570,7 +623,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
'center': center, 'center': center,
'scale': scale, 'scale': scale,
'score': score, 'score': score,
'joints': joints, 'gt_joints': joints,
'joints_vis': joints_vis, 'joints_vis': joints_vis,
}) })
...@@ -647,8 +700,8 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset): ...@@ -647,8 +700,8 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
(self.ann_info['num_joints'], 3), dtype=np.float32) (self.ann_info['num_joints'], 3), dtype=np.float32)
joints_vis = np.zeros( joints_vis = np.zeros(
(self.ann_info['num_joints'], 3), dtype=np.float32) (self.ann_info['num_joints'], 3), dtype=np.float32)
if 'joints' in a: if 'gt_joints' in a:
joints_ = np.array(a['joints']) joints_ = np.array(a['gt_joints'])
joints_[:, 0:2] = joints_[:, 0:2] - 1 joints_[:, 0:2] = joints_[:, 0:2] - 1
joints_vis_ = np.array(a['joints_vis']) joints_vis_ = np.array(a['joints_vis'])
assert len(joints_) == self.ann_info[ assert len(joints_) == self.ann_info[
...@@ -664,7 +717,7 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset): ...@@ -664,7 +717,7 @@ class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
'im_id': im_id, 'im_id': im_id,
'center': c, 'center': c,
'scale': s, 'scale': s,
'joints': joints, 'gt_joints': joints,
'joints_vis': joints_vis 'joints_vis': joints_vis
}) })
print("number length: {}".format(len(gt_db))) print("number length: {}".format(len(gt_db)))
......
...@@ -1102,13 +1102,115 @@ class PadGT(BaseOperator): ...@@ -1102,13 +1102,115 @@ class PadGT(BaseOperator):
1 means bbox, 0 means no bbox. 1 means bbox, 0 means no bbox.
""" """
def __init__(self, return_gt_mask=True): def __init__(self, return_gt_mask=True, pad_img=False, minimum_gtnum=0):
super(PadGT, self).__init__() super(PadGT, self).__init__()
self.return_gt_mask = return_gt_mask self.return_gt_mask = return_gt_mask
self.pad_img = pad_img
self.minimum_gtnum = minimum_gtnum
def _impad(self, img: np.ndarray,
*,
shape = None,
padding = None,
pad_val = 0,
padding_mode = 'constant') -> np.ndarray:
"""Pad the given image to a certain shape or pad on all sides with
specified padding mode and padding value.
Args:
img (ndarray): Image to be padded.
shape (tuple[int]): Expected padding shape (h, w). Default: None.
padding (int or tuple[int]): Padding on each border. If a single int is
provided this is used to pad all borders. If tuple of length 2 is
provided this is the padding on left/right and top/bottom
respectively. If a tuple of length 4 is provided this is the
padding for the left, top, right and bottom borders respectively.
Default: None. Note that `shape` and `padding` can not be both
set.
pad_val (Number | Sequence[Number]): Values to be filled in padding
areas when padding_mode is 'constant'. Default: 0.
padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Default: constant.
- constant: pads with a constant value, this value is specified
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with 2
elements on both sides in reflect mode will result in
[3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last value
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
Returns:
ndarray: The padded image.
"""
assert (shape is not None) ^ (padding is not None)
if shape is not None:
width = max(shape[1] - img.shape[1], 0)
height = max(shape[0] - img.shape[0], 0)
padding = (0, 0, int(width), int(height))
# check pad_val
import numbers
if isinstance(pad_val, tuple):
assert len(pad_val) == img.shape[-1]
elif not isinstance(pad_val, numbers.Number):
raise TypeError('pad_val must be a int or a tuple. '
f'But received {type(pad_val)}')
# check padding
if isinstance(padding, tuple) and len(padding) in [2, 4]:
if len(padding) == 2:
padding = (padding[0], padding[1], padding[0], padding[1])
elif isinstance(padding, numbers.Number):
padding = (padding, padding, padding, padding)
else:
raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
f'But received {padding}')
# check padding mode
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
border_type = {
'constant': cv2.BORDER_CONSTANT,
'edge': cv2.BORDER_REPLICATE,
'reflect': cv2.BORDER_REFLECT_101,
'symmetric': cv2.BORDER_REFLECT
}
img = cv2.copyMakeBorder(
img,
padding[1],
padding[3],
padding[0],
padding[2],
border_type[padding_mode],
value=pad_val)
return img
def checkmaxshape(self, samples):
maxh, maxw = 0, 0
for sample in samples:
h,w = sample['im_shape']
if h>maxh:
maxh = h
if w>maxw:
maxw = w
return (maxh, maxw)
def __call__(self, samples, context=None): def __call__(self, samples, context=None):
num_max_boxes = max([len(s['gt_bbox']) for s in samples]) num_max_boxes = max([len(s['gt_bbox']) for s in samples])
num_max_boxes = max(self.minimum_gtnum, num_max_boxes)
if self.pad_img:
maxshape = self.checkmaxshape(samples)
for sample in samples: for sample in samples:
if self.pad_img:
img = sample['image']
padimg = self._impad(img, shape=maxshape)
sample['image'] = padimg
if self.return_gt_mask: if self.return_gt_mask:
sample['pad_gt_mask'] = np.zeros( sample['pad_gt_mask'] = np.zeros(
(num_max_boxes, 1), dtype=np.float32) (num_max_boxes, 1), dtype=np.float32)
...@@ -1142,6 +1244,17 @@ class PadGT(BaseOperator): ...@@ -1142,6 +1244,17 @@ class PadGT(BaseOperator):
if num_gt > 0: if num_gt > 0:
pad_diff[:num_gt] = sample['difficult'] pad_diff[:num_gt] = sample['difficult']
sample['difficult'] = pad_diff sample['difficult'] = pad_diff
if 'gt_joints' in sample:
num_joints = sample['gt_joints'].shape[1]
pad_gt_joints = np.zeros((num_max_boxes, num_joints, 3), dtype=np.float32)
if num_gt > 0:
pad_gt_joints[:num_gt] = sample['gt_joints']
sample['gt_joints'] = pad_gt_joints
if 'gt_areas' in sample:
pad_gt_areas = np.zeros((num_max_boxes, 1), dtype=np.float32)
if num_gt > 0:
pad_gt_areas[:num_gt, 0] = sample['gt_areas']
sample['gt_areas'] = pad_gt_areas
return samples return samples
......
...@@ -594,6 +594,108 @@ class RandomDistort(BaseOperator): ...@@ -594,6 +594,108 @@ class RandomDistort(BaseOperator):
return sample return sample
@register_op
class PhotoMetricDistortion(BaseOperator):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
8. randomly swap channels
Args:
brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast.
saturation_range (tuple): range of saturation.
hue_delta (int): delta of hue.
"""
def __init__(self,
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18):
super(PhotoMetricDistortion, self).__init__()
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
def apply(self, results, context=None):
"""Call function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
img = results['image']
img = img.astype(np.float32)
# random brightness
if np.random.randint(2):
delta = np.random.uniform(-self.brightness_delta,
self.brightness_delta)
img += delta
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = np.random.randint(2)
if mode == 1:
if np.random.randint(2):
alpha = np.random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# convert color from BGR to HSV
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
# random saturation
if np.random.randint(2):
img[..., 1] *= np.random.uniform(self.saturation_lower,
self.saturation_upper)
# random hue
if np.random.randint(2):
img[..., 0] += np.random.uniform(-self.hue_delta, self.hue_delta)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
# convert color from HSV to BGR
img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
# random contrast
if mode == 0:
if np.random.randint(2):
alpha = np.random.uniform(self.contrast_lower,
self.contrast_upper)
img *= alpha
# randomly swap channels
if np.random.randint(2):
img = img[..., np.random.permutation(3)]
results['image'] = img
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
repr_str += 'contrast_range='
repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
repr_str += 'saturation_range='
repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
repr_str += f'hue_delta={self.hue_delta})'
return repr_str
@register_op @register_op
class AutoAugment(BaseOperator): class AutoAugment(BaseOperator):
def __init__(self, autoaug_type="v1"): def __init__(self, autoaug_type="v1"):
...@@ -771,6 +873,19 @@ class Resize(BaseOperator): ...@@ -771,6 +873,19 @@ class Resize(BaseOperator):
bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h) bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h)
return bbox return bbox
def apply_area(self, area, scale):
im_scale_x, im_scale_y = scale
return area * im_scale_x * im_scale_y
def apply_joints(self, joints, scale, size):
im_scale_x, im_scale_y = scale
resize_w, resize_h = size
joints[..., 0] *= im_scale_x
joints[..., 1] *= im_scale_y
joints[..., 0] = np.clip(joints[..., 0], 0, resize_w)
joints[..., 1] = np.clip(joints[..., 1], 0, resize_h)
return joints
def apply_segm(self, segms, im_size, scale): def apply_segm(self, segms, im_size, scale):
def _resize_poly(poly, im_scale_x, im_scale_y): def _resize_poly(poly, im_scale_x, im_scale_y):
resized_poly = np.array(poly).astype('float32') resized_poly = np.array(poly).astype('float32')
...@@ -833,8 +948,8 @@ class Resize(BaseOperator): ...@@ -833,8 +948,8 @@ class Resize(BaseOperator):
im_scale = min(target_size_min / im_size_min, im_scale = min(target_size_min / im_size_min,
target_size_max / im_size_max) target_size_max / im_size_max)
resize_h = im_scale * float(im_shape[0]) resize_h = int(im_scale * float(im_shape[0]) + 0.5)
resize_w = im_scale * float(im_shape[1]) resize_w = int(im_scale * float(im_shape[1]) + 0.5)
im_scale_x = im_scale im_scale_x = im_scale
im_scale_y = im_scale im_scale_y = im_scale
...@@ -878,6 +993,11 @@ class Resize(BaseOperator): ...@@ -878,6 +993,11 @@ class Resize(BaseOperator):
[im_scale_x, im_scale_y], [im_scale_x, im_scale_y],
[resize_w, resize_h]) [resize_w, resize_h])
# apply areas
if 'gt_areas' in sample:
sample['gt_areas'] = self.apply_area(sample['gt_areas'],
[im_scale_x, im_scale_y])
# apply polygon # apply polygon
if 'gt_poly' in sample and len(sample['gt_poly']) > 0: if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2], sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2],
...@@ -911,6 +1031,11 @@ class Resize(BaseOperator): ...@@ -911,6 +1031,11 @@ class Resize(BaseOperator):
] ]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8) sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
if 'gt_joints' in sample:
sample['gt_joints'] = self.apply_joints(sample['gt_joints'],
[im_scale_x, im_scale_y],
[resize_w, resize_h])
return sample return sample
...@@ -1362,7 +1487,8 @@ class RandomCrop(BaseOperator): ...@@ -1362,7 +1487,8 @@ class RandomCrop(BaseOperator):
num_attempts=50, num_attempts=50,
allow_no_crop=True, allow_no_crop=True,
cover_all_box=False, cover_all_box=False,
is_mask_crop=False): is_mask_crop=False,
ioumode="iou"):
super(RandomCrop, self).__init__() super(RandomCrop, self).__init__()
self.aspect_ratio = aspect_ratio self.aspect_ratio = aspect_ratio
self.thresholds = thresholds self.thresholds = thresholds
...@@ -1371,6 +1497,7 @@ class RandomCrop(BaseOperator): ...@@ -1371,6 +1497,7 @@ class RandomCrop(BaseOperator):
self.allow_no_crop = allow_no_crop self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box self.cover_all_box = cover_all_box
self.is_mask_crop = is_mask_crop self.is_mask_crop = is_mask_crop
self.ioumode = ioumode
def crop_segms(self, segms, valid_ids, crop, height, width): def crop_segms(self, segms, valid_ids, crop, height, width):
def _crop_poly(segm, crop): def _crop_poly(segm, crop):
...@@ -1516,9 +1643,14 @@ class RandomCrop(BaseOperator): ...@@ -1516,9 +1643,14 @@ class RandomCrop(BaseOperator):
crop_y = np.random.randint(0, h - crop_h) crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w) crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
iou = self._iou_matrix( if self.ioumode == "iof":
gt_bbox, np.array( iou = self._gtcropiou_matrix(
[crop_box], dtype=np.float32)) gt_bbox, np.array(
[crop_box], dtype=np.float32))
elif self.ioumode == "iou":
iou = self._iou_matrix(
gt_bbox, np.array(
[crop_box], dtype=np.float32))
if iou.max() < thresh: if iou.max() < thresh:
continue continue
...@@ -1582,6 +1714,10 @@ class RandomCrop(BaseOperator): ...@@ -1582,6 +1714,10 @@ class RandomCrop(BaseOperator):
sample['difficult'] = np.take( sample['difficult'] = np.take(
sample['difficult'], valid_ids, axis=0) sample['difficult'], valid_ids, axis=0)
if 'gt_joints' in sample:
sample['gt_joints'] = self._crop_joints(sample['gt_joints'],
crop_box)
return sample return sample
return sample return sample
...@@ -1596,6 +1732,16 @@ class RandomCrop(BaseOperator): ...@@ -1596,6 +1732,16 @@ class RandomCrop(BaseOperator):
area_o = (area_a[:, np.newaxis] + area_b - area_i) area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10) return area_i / (area_o + 1e-10)
def _gtcropiou_matrix(self, a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_a + 1e-10)
def _crop_box_with_center_constraint(self, box, crop): def _crop_box_with_center_constraint(self, box, crop):
cropped_box = box.copy() cropped_box = box.copy()
...@@ -1620,6 +1766,16 @@ class RandomCrop(BaseOperator): ...@@ -1620,6 +1766,16 @@ class RandomCrop(BaseOperator):
x1, y1, x2, y2 = crop x1, y1, x2, y2 = crop
return segm[:, y1:y2, x1:x2] return segm[:, y1:y2, x1:x2]
def _crop_joints(self, joints, crop):
x1, y1, x2, y2 = crop
joints[joints[..., 0] > x2, :] = 0
joints[joints[..., 1] > y2, :] = 0
joints[joints[..., 0] < x1, :] = 0
joints[joints[..., 1] < y1, :] = 0
joints[..., 0] -= x1
joints[..., 1] -= y1
return joints
@register_op @register_op
class RandomScaledCrop(BaseOperator): class RandomScaledCrop(BaseOperator):
...@@ -1648,8 +1804,8 @@ class RandomScaledCrop(BaseOperator): ...@@ -1648,8 +1804,8 @@ class RandomScaledCrop(BaseOperator):
random_dim = int(dim * random_scale) random_dim = int(dim * random_scale)
dim_max = max(h, w) dim_max = max(h, w)
scale = random_dim / dim_max scale = random_dim / dim_max
resize_w = w * scale resize_w = int(w * scale + 0.5)
resize_h = h * scale resize_h = int(h * scale + 0.5)
offset_x = int(max(0, np.random.uniform(0., resize_w - dim))) offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
offset_y = int(max(0, np.random.uniform(0., resize_h - dim))) offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
...@@ -2316,25 +2472,26 @@ class RandomResizeCrop(BaseOperator): ...@@ -2316,25 +2472,26 @@ class RandomResizeCrop(BaseOperator):
is_mask_crop(bool): whether crop the segmentation. is_mask_crop(bool): whether crop the segmentation.
""" """
def __init__( def __init__(self,
self, resizes,
resizes, cropsizes,
cropsizes, prob=0.5,
prob=0.5, mode='short',
mode='short', keep_ratio=True,
keep_ratio=True, interp=cv2.INTER_LINEAR,
interp=cv2.INTER_LINEAR, num_attempts=3,
num_attempts=3, cover_all_box=False,
cover_all_box=False, allow_no_crop=False,
allow_no_crop=False, thresholds=[0.3, 0.5, 0.7],
thresholds=[0.3, 0.5, 0.7], is_mask_crop=False,
is_mask_crop=False, ): ioumode="iou"):
super(RandomResizeCrop, self).__init__() super(RandomResizeCrop, self).__init__()
self.resizes = resizes self.resizes = resizes
self.cropsizes = cropsizes self.cropsizes = cropsizes
self.prob = prob self.prob = prob
self.mode = mode self.mode = mode
self.ioumode = ioumode
self.resizer = Resize(0, keep_ratio=keep_ratio, interp=interp) self.resizer = Resize(0, keep_ratio=keep_ratio, interp=interp)
self.croper = RandomCrop( self.croper = RandomCrop(
...@@ -2389,9 +2546,14 @@ class RandomResizeCrop(BaseOperator): ...@@ -2389,9 +2546,14 @@ class RandomResizeCrop(BaseOperator):
crop_x = random.randint(0, w - crop_w) crop_x = random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
iou = self._iou_matrix( if self.ioumode == "iof":
gt_bbox, np.array( iou = self._gtcropiou_matrix(
[crop_box], dtype=np.float32)) gt_bbox, np.array(
[crop_box], dtype=np.float32))
elif self.ioumode == "iou":
iou = self._iou_matrix(
gt_bbox, np.array(
[crop_box], dtype=np.float32))
if iou.max() < thresh: if iou.max() < thresh:
continue continue
...@@ -2447,6 +2609,14 @@ class RandomResizeCrop(BaseOperator): ...@@ -2447,6 +2609,14 @@ class RandomResizeCrop(BaseOperator):
if 'is_crowd' in sample: if 'is_crowd' in sample:
sample['is_crowd'] = np.take( sample['is_crowd'] = np.take(
sample['is_crowd'], valid_ids, axis=0) sample['is_crowd'], valid_ids, axis=0)
if 'gt_areas' in sample:
sample['gt_areas'] = np.take(
sample['gt_areas'], valid_ids, axis=0)
if 'gt_joints' in sample:
gt_joints = self._crop_joints(sample['gt_joints'], crop_box)
sample['gt_joints'] = gt_joints[valid_ids]
return sample return sample
return sample return sample
...@@ -2479,8 +2649,8 @@ class RandomResizeCrop(BaseOperator): ...@@ -2479,8 +2649,8 @@ class RandomResizeCrop(BaseOperator):
im_scale = max(target_size_min / im_size_min, im_scale = max(target_size_min / im_size_min,
target_size_max / im_size_max) target_size_max / im_size_max)
resize_h = im_scale * float(im_shape[0]) resize_h = int(im_scale * float(im_shape[0]) + 0.5)
resize_w = im_scale * float(im_shape[1]) resize_w = int(im_scale * float(im_shape[1]) + 0.5)
im_scale_x = im_scale im_scale_x = im_scale
im_scale_y = im_scale im_scale_y = im_scale
...@@ -2540,6 +2710,11 @@ class RandomResizeCrop(BaseOperator): ...@@ -2540,6 +2710,11 @@ class RandomResizeCrop(BaseOperator):
] ]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8) sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
if 'gt_joints' in sample:
sample['gt_joints'] = self.apply_joints(sample['gt_joints'],
[im_scale_x, im_scale_y],
[resize_w, resize_h])
return sample return sample
...@@ -2612,10 +2787,10 @@ class RandomShortSideResize(BaseOperator): ...@@ -2612,10 +2787,10 @@ class RandomShortSideResize(BaseOperator):
if w < h: if w < h:
ow = size ow = size
oh = int(size * h / w) oh = int(round(size * h / w))
else: else:
oh = size oh = size
ow = int(size * w / h) ow = int(round(size * w / h))
return (ow, oh) return (ow, oh)
...@@ -2672,6 +2847,16 @@ class RandomShortSideResize(BaseOperator): ...@@ -2672,6 +2847,16 @@ class RandomShortSideResize(BaseOperator):
for gt_segm in sample['gt_segm'] for gt_segm in sample['gt_segm']
] ]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8) sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
if 'gt_joints' in sample:
sample['gt_joints'] = self.apply_joints(
sample['gt_joints'], [im_scale_x, im_scale_y], target_size)
# apply areas
if 'gt_areas' in sample:
sample['gt_areas'] = self.apply_area(sample['gt_areas'],
[im_scale_x, im_scale_y])
return sample return sample
def apply_bbox(self, bbox, scale, size): def apply_bbox(self, bbox, scale, size):
...@@ -2683,6 +2868,23 @@ class RandomShortSideResize(BaseOperator): ...@@ -2683,6 +2868,23 @@ class RandomShortSideResize(BaseOperator):
bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h) bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h)
return bbox.astype('float32') return bbox.astype('float32')
def apply_joints(self, joints, scale, size):
im_scale_x, im_scale_y = scale
resize_w, resize_h = size
joints[..., 0] *= im_scale_x
joints[..., 1] *= im_scale_y
# joints[joints[..., 0] >= resize_w, :] = 0
# joints[joints[..., 1] >= resize_h, :] = 0
# joints[joints[..., 0] < 0, :] = 0
# joints[joints[..., 1] < 0, :] = 0
joints[..., 0] = np.clip(joints[..., 0], 0, resize_w)
joints[..., 1] = np.clip(joints[..., 1], 0, resize_h)
return joints
def apply_area(self, area, scale):
im_scale_x, im_scale_y = scale
return area * im_scale_x * im_scale_y
def apply_segm(self, segms, im_size, scale): def apply_segm(self, segms, im_size, scale):
def _resize_poly(poly, im_scale_x, im_scale_y): def _resize_poly(poly, im_scale_x, im_scale_y):
resized_poly = np.array(poly).astype('float32') resized_poly = np.array(poly).astype('float32')
...@@ -2730,6 +2932,44 @@ class RandomShortSideResize(BaseOperator): ...@@ -2730,6 +2932,44 @@ class RandomShortSideResize(BaseOperator):
return self.resize(sample, target_size, self.max_size, interp) return self.resize(sample, target_size, self.max_size, interp)
@register_op
class RandomShortSideRangeResize(RandomShortSideResize):
def __init__(self, scales, interp=cv2.INTER_LINEAR, random_interp=False):
"""
Resize the image randomly according to the short side. If max_size is not None,
the long side is scaled according to max_size. The whole process will be keep ratio.
Args:
short_side_sizes (list|tuple): Image target short side size.
interp (int): The interpolation method.
random_interp (bool): Whether random select interpolation method.
"""
super(RandomShortSideRangeResize, self).__init__(scales, None, interp,
random_interp)
assert isinstance(scales,
Sequence), "short_side_sizes must be List or Tuple"
self.scales = scales
def random_sample(self, img_scales):
img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales]
long_edge = np.random.randint(
min(img_scale_long), max(img_scale_long) + 1)
short_edge = np.random.randint(
min(img_scale_short), max(img_scale_short) + 1)
img_scale = (long_edge, short_edge)
return img_scale
def apply(self, sample, context=None):
long_edge, short_edge = self.random_sample(self.short_side_sizes)
# print("target size:{}".format((long_edge, short_edge)))
interp = random.choice(
self.interps) if self.random_interp else self.interp
return self.resize(sample, short_edge, long_edge, interp)
@register_op @register_op
class RandomSizeCrop(BaseOperator): class RandomSizeCrop(BaseOperator):
""" """
...@@ -2805,6 +3045,9 @@ class RandomSizeCrop(BaseOperator): ...@@ -2805,6 +3045,9 @@ class RandomSizeCrop(BaseOperator):
sample['is_crowd'] = sample['is_crowd'][keep_index] if len( sample['is_crowd'] = sample['is_crowd'][keep_index] if len(
keep_index) > 0 else np.zeros( keep_index) > 0 else np.zeros(
[0, 1], dtype=np.float32) [0, 1], dtype=np.float32)
if 'gt_areas' in sample:
sample['gt_areas'] = np.take(
sample['gt_areas'], keep_index, axis=0)
image_shape = sample['image'].shape[:2] image_shape = sample['image'].shape[:2]
sample['image'] = self.paddle_crop(sample['image'], *region) sample['image'] = self.paddle_crop(sample['image'], *region)
...@@ -2826,6 +3069,12 @@ class RandomSizeCrop(BaseOperator): ...@@ -2826,6 +3069,12 @@ class RandomSizeCrop(BaseOperator):
if keep_index is not None and len(keep_index) > 0: if keep_index is not None and len(keep_index) > 0:
sample['gt_segm'] = sample['gt_segm'][keep_index] sample['gt_segm'] = sample['gt_segm'][keep_index]
if 'gt_joints' in sample:
gt_joints = self._crop_joints(sample['gt_joints'], region)
sample['gt_joints'] = gt_joints
if keep_index is not None:
sample['gt_joints'] = sample['gt_joints'][keep_index]
return sample return sample
def apply_bbox(self, bbox, region): def apply_bbox(self, bbox, region):
...@@ -2836,6 +3085,19 @@ class RandomSizeCrop(BaseOperator): ...@@ -2836,6 +3085,19 @@ class RandomSizeCrop(BaseOperator):
crop_bbox = crop_bbox.clip(min=0) crop_bbox = crop_bbox.clip(min=0)
return crop_bbox.reshape([-1, 4]).astype('float32') return crop_bbox.reshape([-1, 4]).astype('float32')
def _crop_joints(self, joints, region):
y1, x1, h, w = region
x2 = x1 + w
y2 = y1 + h
# x1, y1, x2, y2 = crop
joints[..., 0] -= x1
joints[..., 1] -= y1
joints[joints[..., 0] > w, :] = 0
joints[joints[..., 1] > h, :] = 0
joints[joints[..., 0] < 0, :] = 0
joints[joints[..., 1] < 0, :] = 0
return joints
def apply_segm(self, segms, region, image_shape): def apply_segm(self, segms, region, image_shape):
def _crop_poly(segm, crop): def _crop_poly(segm, crop):
xmin, ymin, xmax, ymax = crop xmin, ymin, xmax, ymax = crop
......
...@@ -72,3 +72,4 @@ from .yolof import * ...@@ -72,3 +72,4 @@ from .yolof import *
from .pose3d_metro import * from .pose3d_metro import *
from .centertrack import * from .centertrack import *
from .queryinst import * from .queryinst import *
from .keypoint_petr import *
...@@ -394,6 +394,7 @@ class TinyPose3DHRNet(BaseArch): ...@@ -394,6 +394,7 @@ class TinyPose3DHRNet(BaseArch):
def __init__(self, def __init__(self,
width, width,
num_joints, num_joints,
fc_channel=768,
backbone='HRNet', backbone='HRNet',
loss='KeyPointRegressionMSELoss', loss='KeyPointRegressionMSELoss',
post_process=TinyPose3DPostProcess): post_process=TinyPose3DPostProcess):
...@@ -411,21 +412,13 @@ class TinyPose3DHRNet(BaseArch): ...@@ -411,21 +412,13 @@ class TinyPose3DHRNet(BaseArch):
self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True) self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True)
self.final_conv_new = L.Conv2d(
width, num_joints * 32, 1, 1, 0, bias=True)
self.flatten = paddle.nn.Flatten(start_axis=2, stop_axis=3) self.flatten = paddle.nn.Flatten(start_axis=2, stop_axis=3)
self.fc1 = paddle.nn.Linear(768, 256) self.fc1 = paddle.nn.Linear(fc_channel, 256)
self.act1 = paddle.nn.ReLU() self.act1 = paddle.nn.ReLU()
self.fc2 = paddle.nn.Linear(256, 64) self.fc2 = paddle.nn.Linear(256, 64)
self.act2 = paddle.nn.ReLU() self.act2 = paddle.nn.ReLU()
self.fc3 = paddle.nn.Linear(64, 3) self.fc3 = paddle.nn.Linear(64, 3)
# for human3.6M
self.fc1_1 = paddle.nn.Linear(3136, 1024)
self.fc2_1 = paddle.nn.Linear(1024, 256)
self.fc3_1 = paddle.nn.Linear(256, 3)
@classmethod @classmethod
def from_config(cls, cfg, *args, **kwargs): def from_config(cls, cfg, *args, **kwargs):
# backbone # backbone
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is base on https://github.com/hikvision-research/opera/blob/main/opera/models/detectors/petr.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch
from .. import layers as L
__all__ = ['PETR']
@register
class PETR(BaseArch):
__category__ = 'architecture'
__inject__ = ['backbone', 'neck', 'bbox_head']
def __init__(self,
backbone='ResNet',
neck='ChannelMapper',
bbox_head='PETRHead'):
"""
PETR, see https://openaccess.thecvf.com/content/CVPR2022/papers/Shi_End-to-End_Multi-Person_Pose_Estimation_With_Transformers_CVPR_2022_paper.pdf
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck between backbone and head
bbox_head (nn.Layer): model output and loss
"""
super(PETR, self).__init__()
self.backbone = backbone
if neck is not None:
self.with_neck = True
self.neck = neck
self.bbox_head = bbox_head
self.deploy = False
def extract_feat(self, img):
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def get_inputs(self):
img_metas = []
gt_bboxes = []
gt_labels = []
gt_keypoints = []
gt_areas = []
pad_gt_mask = self.inputs['pad_gt_mask'].astype("bool").squeeze(-1)
for idx, im_shape in enumerate(self.inputs['im_shape']):
img_meta = {
'img_shape': im_shape.astype("int32").tolist() + [1, ],
'batch_input_shape': self.inputs['image'].shape[-2:],
'image_name': self.inputs['image_file'][idx]
}
img_metas.append(img_meta)
if (not pad_gt_mask[idx].any()):
gt_keypoints.append(self.inputs['gt_joints'][idx][:1])
gt_labels.append(self.inputs['gt_class'][idx][:1])
gt_bboxes.append(self.inputs['gt_bbox'][idx][:1])
gt_areas.append(self.inputs['gt_areas'][idx][:1])
continue
gt_keypoints.append(self.inputs['gt_joints'][idx][pad_gt_mask[idx]])
gt_labels.append(self.inputs['gt_class'][idx][pad_gt_mask[idx]])
gt_bboxes.append(self.inputs['gt_bbox'][idx][pad_gt_mask[idx]])
gt_areas.append(self.inputs['gt_areas'][idx][pad_gt_mask[idx]])
return img_metas, gt_bboxes, gt_labels, gt_keypoints, gt_areas
def get_loss(self):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
gt_bboxes (list[Tensor]): Each item are the truth boxes for each
image in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): Class indices corresponding to each box.
gt_keypoints (list[Tensor]): Each item are the truth keypoints for
each image in [p^{1}_x, p^{1}_y, p^{1}_v, ..., p^{K}_x,
p^{K}_y, p^{K}_v] format.
gt_areas (list[Tensor]): mask areas corresponding to each box.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
img_metas, gt_bboxes, gt_labels, gt_keypoints, gt_areas = self.get_inputs(
)
gt_bboxes_ignore = getattr(self.inputs, 'gt_bboxes_ignore', None)
x = self.extract_feat(self.inputs)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_keypoints, gt_areas,
gt_bboxes_ignore)
loss = 0
for k, v in losses.items():
loss += v
losses['loss'] = loss
return losses
def get_pred_numpy(self):
"""Used for computing network flops.
"""
img = self.inputs['image']
batch_size, _, height, width = img.shape
dummy_img_metas = [
dict(
batch_input_shape=(height, width),
img_shape=(height, width, 3),
scale_factor=(1., 1., 1., 1.)) for _ in range(batch_size)
]
x = self.extract_feat(img)
outs = self.bbox_head(x, img_metas=dummy_img_metas)
bbox_list = self.bbox_head.get_bboxes(
*outs, dummy_img_metas, rescale=True)
return bbox_list
def get_pred(self):
"""
"""
img = self.inputs['image']
batch_size, _, height, width = img.shape
img_metas = [
dict(
batch_input_shape=(height, width),
img_shape=(height, width, 3),
scale_factor=self.inputs['scale_factor'][i])
for i in range(batch_size)
]
kptpred = self.simple_test(
self.inputs, img_metas=img_metas, rescale=True)
keypoints = kptpred[0][1][0]
bboxs = kptpred[0][0][0]
keypoints[..., 2] = bboxs[:, None, 4]
res_lst = [[keypoints, bboxs[:, 4]]]
outputs = {'keypoint': res_lst}
return outputs
def simple_test(self, inputs, img_metas, rescale=False):
"""Test function without test time augmentation.
Args:
inputs (list[paddle.Tensor]): List of multiple images.
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[list[np.ndarray]]: BBox and keypoint results of each image
and classes. The outer list corresponds to each image.
The inner list corresponds to each class.
"""
batch_size = len(img_metas)
assert batch_size == 1, 'Currently only batch_size 1 for inference ' \
f'mode is supported. Found batch_size {batch_size}.'
feat = self.extract_feat(inputs)
results_list = self.bbox_head.simple_test(
feat, img_metas, rescale=rescale)
bbox_kpt_results = [
self.bbox_kpt2result(det_bboxes, det_labels, det_kpts,
self.bbox_head.num_classes)
for det_bboxes, det_labels, det_kpts in results_list
]
return bbox_kpt_results
def bbox_kpt2result(self, bboxes, labels, kpts, num_classes):
"""Convert detection results to a list of numpy arrays.
Args:
bboxes (paddle.Tensor | np.ndarray): shape (n, 5).
labels (paddle.Tensor | np.ndarray): shape (n, ).
kpts (paddle.Tensor | np.ndarray): shape (n, K, 3).
num_classes (int): class number, including background class.
Returns:
list(ndarray): bbox and keypoint results of each class.
"""
if bboxes.shape[0] == 0:
return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)], \
[np.zeros((0, kpts.size(1), 3), dtype=np.float32)
for i in range(num_classes)]
else:
if isinstance(bboxes, paddle.Tensor):
bboxes = bboxes.numpy()
labels = labels.numpy()
kpts = kpts.numpy()
return [bboxes[labels == i, :] for i in range(num_classes)], \
[kpts[labels == i, :, :] for i in range(num_classes)]
...@@ -31,3 +31,5 @@ from .fcosr_assigner import * ...@@ -31,3 +31,5 @@ from .fcosr_assigner import *
from .rotated_task_aligned_assigner import * from .rotated_task_aligned_assigner import *
from .task_aligned_assigner_cr import * from .task_aligned_assigner_cr import *
from .uniform_assigner import * from .uniform_assigner import *
from .hungarian_assigner import *
from .pose_utils import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from scipy.optimize import linear_sum_assignment
except ImportError:
linear_sum_assignment = None
import paddle
from ppdet.core.workspace import register
__all__ = ['PoseHungarianAssigner', 'PseudoSampler']
class AssignResult:
"""Stores assignments between predicted and truth boxes.
Attributes:
num_gts (int): the number of truth boxes considered when computing this
assignment
gt_inds (LongTensor): for each predicted box indicates the 1-based
index of the assigned truth box. 0 means unassigned and -1 means
ignore.
max_overlaps (FloatTensor): the iou between the predicted box and its
assigned truth box.
labels (None | LongTensor): If specified, for each predicted box
indicates the category label of the assigned truth box.
"""
def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
self.num_gts = num_gts
self.gt_inds = gt_inds
self.max_overlaps = max_overlaps
self.labels = labels
# Interface for possible user-defined properties
self._extra_properties = {}
@property
def num_preds(self):
"""int: the number of predictions in this assignment"""
return len(self.gt_inds)
def set_extra_property(self, key, value):
"""Set user-defined new property."""
assert key not in self.info
self._extra_properties[key] = value
def get_extra_property(self, key):
"""Get user-defined property."""
return self._extra_properties.get(key, None)
@property
def info(self):
"""dict: a dictionary of info about the object"""
basic_info = {
'num_gts': self.num_gts,
'num_preds': self.num_preds,
'gt_inds': self.gt_inds,
'max_overlaps': self.max_overlaps,
'labels': self.labels,
}
basic_info.update(self._extra_properties)
return basic_info
@register
class PoseHungarianAssigner:
"""Computes one-to-one matching between predictions and ground truth.
This class computes an assignment between the targets and the predictions
based on the costs. The costs are weighted sum of three components:
classification cost, regression L1 cost and regression oks cost. The
targets don't include the no_object, so generally there are more
predictions than targets. After the one-to-one matching, the un-matched
are treated as backgrounds. Thus each query prediction will be assigned
with `0` or a positive integer indicating the ground truth index:
- 0: negative sample, no assigned gt.
- positive integer: positive sample, index (1-based) of assigned gt.
Args:
cls_weight (int | float, optional): The scale factor for classification
cost. Default 1.0.
kpt_weight (int | float, optional): The scale factor for regression
L1 cost. Default 1.0.
oks_weight (int | float, optional): The scale factor for regression
oks cost. Default 1.0.
"""
__inject__ = ['cls_cost', 'kpt_cost', 'oks_cost']
def __init__(self,
cls_cost='ClassificationCost',
kpt_cost='KptL1Cost',
oks_cost='OksCost'):
self.cls_cost = cls_cost
self.kpt_cost = kpt_cost
self.oks_cost = oks_cost
def assign(self,
cls_pred,
kpt_pred,
gt_labels,
gt_keypoints,
gt_areas,
img_meta,
eps=1e-7):
"""Computes one-to-one matching based on the weighted costs.
This method assign each query prediction to a ground truth or
background. The `assigned_gt_inds` with -1 means don't care,
0 means negative sample, and positive number is the index (1-based)
of assigned gt.
The assignment is done in the following steps, the order matters.
1. assign every prediction to -1
2. compute the weighted costs
3. do Hungarian matching on CPU based on the costs
4. assign all to 0 (background) first, then for each matched pair
between predictions and gts, treat this prediction as foreground
and assign the corresponding gt index (plus 1) to it.
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
kpt_pred (Tensor): Predicted keypoints with normalized coordinates
(x_{i}, y_{i}), which are all in range [0, 1]. Shape
[num_query, K*2].
gt_labels (Tensor): Label of `gt_keypoints`, shape (num_gt,).
gt_keypoints (Tensor): Ground truth keypoints with unnormalized
coordinates [p^{1}_x, p^{1}_y, p^{1}_v, ..., \
p^{K}_x, p^{K}_y, p^{K}_v]. Shape [num_gt, K*3].
gt_areas (Tensor): Ground truth mask areas, shape (num_gt,).
img_meta (dict): Meta information for current image.
eps (int | float, optional): A value added to the denominator for
numerical stability. Default 1e-7.
Returns:
:obj:`AssignResult`: The assigned result.
"""
num_gts, num_kpts = gt_keypoints.shape[0], kpt_pred.shape[0]
if not gt_keypoints.astype('bool').any():
num_gts = 0
# 1. assign -1 by default
assigned_gt_inds = paddle.full((num_kpts, ), -1, dtype="int64")
assigned_labels = paddle.full((num_kpts, ), -1, dtype="int64")
if num_gts == 0 or num_kpts == 0:
# No ground truth or keypoints, return empty assignment
if num_gts == 0:
# No ground truth, assign all to background
assigned_gt_inds[:] = 0
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
img_h, img_w, _ = img_meta['img_shape']
factor = paddle.to_tensor(
[img_w, img_h, img_w, img_h], dtype=gt_keypoints.dtype).reshape(
(1, -1))
# 2. compute the weighted costs
# classification cost
cls_cost = self.cls_cost(cls_pred, gt_labels)
# keypoint regression L1 cost
gt_keypoints_reshape = gt_keypoints.reshape((gt_keypoints.shape[0], -1,
3))
valid_kpt_flag = gt_keypoints_reshape[..., -1]
kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1,
2))
normalize_gt_keypoints = gt_keypoints_reshape[
..., :2] / factor[:, :2].unsqueeze(0)
kpt_cost = self.kpt_cost(kpt_pred_tmp, normalize_gt_keypoints,
valid_kpt_flag)
# keypoint OKS cost
kpt_pred_tmp = kpt_pred.clone().detach().reshape((kpt_pred.shape[0], -1,
2))
kpt_pred_tmp = kpt_pred_tmp * factor[:, :2].unsqueeze(0)
oks_cost = self.oks_cost(kpt_pred_tmp, gt_keypoints_reshape[..., :2],
valid_kpt_flag, gt_areas)
# weighted sum of above three costs
cost = cls_cost + kpt_cost + oks_cost
# 3. do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
if linear_sum_assignment is None:
raise ImportError('Please run "pip install scipy" '
'to install scipy first.')
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
matched_row_inds = paddle.to_tensor(matched_row_inds)
matched_col_inds = paddle.to_tensor(matched_col_inds)
# 4. assign backgrounds and foregrounds
# assign all indices to backgrounds first
assigned_gt_inds[:] = 0
# assign foregrounds based on matching results
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds][
..., 0].astype("int64")
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
class SamplingResult:
"""Bbox sampling result.
"""
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
gt_flags):
self.pos_inds = pos_inds
self.neg_inds = neg_inds
if pos_inds.size > 0:
self.pos_bboxes = bboxes[pos_inds]
self.neg_bboxes = bboxes[neg_inds]
self.pos_is_gt = gt_flags[pos_inds]
self.num_gts = gt_bboxes.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
if gt_bboxes.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_bboxes = paddle.zeros(
gt_bboxes.shape, dtype=gt_bboxes.dtype).reshape((-1, 4))
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.reshape((-1, 4))
self.pos_gt_bboxes = paddle.index_select(
gt_bboxes,
self.pos_assigned_gt_inds.astype('int64'),
axis=0)
if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds]
else:
self.pos_gt_labels = None
@property
def bboxes(self):
"""paddle.Tensor: concatenated positive and negative boxes"""
return paddle.concat([self.pos_bboxes, self.neg_bboxes])
def __nice__(self):
data = self.info.copy()
data['pos_bboxes'] = data.pop('pos_bboxes').shape
data['neg_bboxes'] = data.pop('neg_bboxes').shape
parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
body = ' ' + ',\n '.join(parts)
return '{\n' + body + '\n}'
@property
def info(self):
"""Returns a dictionary of info about the object."""
return {
'pos_inds': self.pos_inds,
'neg_inds': self.neg_inds,
'pos_bboxes': self.pos_bboxes,
'neg_bboxes': self.neg_bboxes,
'pos_is_gt': self.pos_is_gt,
'num_gts': self.num_gts,
'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
}
@register
class PseudoSampler:
"""A pseudo sampler that does not do sampling actually."""
def __init__(self, **kwargs):
pass
def _sample_pos(self, **kwargs):
"""Sample positive samples."""
raise NotImplementedError
def _sample_neg(self, **kwargs):
"""Sample negative samples."""
raise NotImplementedError
def sample(self, assign_result, bboxes, gt_bboxes, *args, **kwargs):
"""Directly returns the positive and negative indices of samples.
Args:
assign_result (:obj:`AssignResult`): Assigned results
bboxes (paddle.Tensor): Bounding boxes
gt_bboxes (paddle.Tensor): Ground truth boxes
Returns:
:obj:`SamplingResult`: sampler results
"""
pos_inds = paddle.nonzero(
assign_result.gt_inds > 0, as_tuple=False).squeeze(-1)
neg_inds = paddle.nonzero(
assign_result.gt_inds == 0, as_tuple=False).squeeze(-1)
gt_flags = paddle.zeros([bboxes.shape[0]], dtype='int32')
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
return sampling_result
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register
__all__ = ['KptL1Cost', 'OksCost', 'ClassificationCost']
def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
@register
class KptL1Cost(object):
"""KptL1Cost.
this function based on: https://github.com/hikvision-research/opera/blob/main/opera/core/bbox/match_costs/match_cost.py
Args:
weight (int | float, optional): loss_weight.
"""
def __init__(self, weight=1.0):
self.weight = weight
def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag):
"""
Args:
kpt_pred (Tensor): Predicted keypoints with normalized coordinates
(x_{i}, y_{i}), which are all in range [0, 1]. Shape
[num_query, K, 2].
gt_keypoints (Tensor): Ground truth keypoints with normalized
coordinates (x_{i}, y_{i}). Shape [num_gt, K, 2].
valid_kpt_flag (Tensor): valid flag of ground truth keypoints.
Shape [num_gt, K].
Returns:
paddle.Tensor: kpt_cost value with weight.
"""
kpt_cost = []
for i in range(len(gt_keypoints)):
if gt_keypoints[i].size == 0:
kpt_cost.append(kpt_pred.sum() * 0)
kpt_pred_tmp = kpt_pred.clone()
valid_flag = valid_kpt_flag[i] > 0
valid_flag_expand = valid_flag.unsqueeze(0).unsqueeze(-1).expand_as(
kpt_pred_tmp)
if not valid_flag_expand.all():
kpt_pred_tmp = masked_fill(kpt_pred_tmp, ~valid_flag_expand, 0)
cost = F.pairwise_distance(
kpt_pred_tmp.reshape((kpt_pred_tmp.shape[0], -1)),
gt_keypoints[i].reshape((-1, )).unsqueeze(0),
p=1,
keepdim=True)
avg_factor = paddle.clip(
valid_flag.astype('float32').sum() * 2, 1.0)
cost = cost / avg_factor
kpt_cost.append(cost)
kpt_cost = paddle.concat(kpt_cost, axis=1)
return kpt_cost * self.weight
@register
class OksCost(object):
"""OksCost.
this function based on: https://github.com/hikvision-research/opera/blob/main/opera/core/bbox/match_costs/match_cost.py
Args:
num_keypoints (int): number of keypoints
weight (int | float, optional): loss_weight.
"""
def __init__(self, num_keypoints=17, weight=1.0):
self.weight = weight
if num_keypoints == 17:
self.sigmas = np.array(
[
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
1.07, .87, .87, .89, .89
],
dtype=np.float32) / 10.0
elif num_keypoints == 14:
self.sigmas = np.array(
[
.79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89,
.89, .79, .79
],
dtype=np.float32) / 10.0
else:
raise ValueError(f'Unsupported keypoints number {num_keypoints}')
def __call__(self, kpt_pred, gt_keypoints, valid_kpt_flag, gt_areas):
"""
Args:
kpt_pred (Tensor): Predicted keypoints with unnormalized
coordinates (x_{i}, y_{i}). Shape [num_query, K, 2].
gt_keypoints (Tensor): Ground truth keypoints with unnormalized
coordinates (x_{i}, y_{i}). Shape [num_gt, K, 2].
valid_kpt_flag (Tensor): valid flag of ground truth keypoints.
Shape [num_gt, K].
gt_areas (Tensor): Ground truth mask areas. Shape [num_gt,].
Returns:
paddle.Tensor: oks_cost value with weight.
"""
sigmas = paddle.to_tensor(self.sigmas)
variances = (sigmas * 2)**2
oks_cost = []
assert len(gt_keypoints) == len(gt_areas)
for i in range(len(gt_keypoints)):
if gt_keypoints[i].size == 0:
oks_cost.append(kpt_pred.sum() * 0)
squared_distance = \
(kpt_pred[:, :, 0] - gt_keypoints[i, :, 0].unsqueeze(0)) ** 2 + \
(kpt_pred[:, :, 1] - gt_keypoints[i, :, 1].unsqueeze(0)) ** 2
vis_flag = (valid_kpt_flag[i] > 0).astype('int')
vis_ind = vis_flag.nonzero(as_tuple=False)[:, 0]
num_vis_kpt = vis_ind.shape[0]
# assert num_vis_kpt > 0
if num_vis_kpt == 0:
oks_cost.append(paddle.zeros((squared_distance.shape[0], 1)))
continue
area = gt_areas[i]
squared_distance0 = squared_distance / (area * variances * 2)
squared_distance0 = paddle.index_select(
squared_distance0, vis_ind, axis=1)
squared_distance1 = paddle.exp(-squared_distance0).sum(axis=1,
keepdim=True)
oks = squared_distance1 / num_vis_kpt
# The 1 is a constant that doesn't change the matching, so omitted.
oks_cost.append(-oks)
oks_cost = paddle.concat(oks_cost, axis=1)
return oks_cost * self.weight
@register
class ClassificationCost:
"""ClsSoftmaxCost.
Args:
weight (int | float, optional): loss_weight
"""
def __init__(self, weight=1.):
self.weight = weight
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
(num_query, num_class).
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
paddle.Tensor: cls_cost value with weight
"""
# Following the official DETR repo, contrary to the loss that
# NLL is used, we approximate it in 1 - cls_score[gt_label].
# The 1 is a constant that doesn't change the matching,
# so it can be omitted.
cls_score = cls_pred.softmax(-1)
cls_cost = -cls_score[:, gt_labels]
return cls_cost * self.weight
@register
class FocalLossCost:
"""FocalLossCost.
Args:
weight (int | float, optional): loss_weight
alpha (int | float, optional): focal_loss alpha
gamma (int | float, optional): focal_loss gamma
eps (float, optional): default 1e-12
binary_input (bool, optional): Whether the input is binary,
default False.
"""
def __init__(self,
weight=1.,
alpha=0.25,
gamma=2,
eps=1e-12,
binary_input=False):
self.weight = weight
self.alpha = alpha
self.gamma = gamma
self.eps = eps
self.binary_input = binary_input
def _focal_loss_cost(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
(num_query, num_class).
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
paddle.Tensor: cls_cost value with weight
"""
if gt_labels.size == 0:
return cls_pred.sum() * 0
cls_pred = F.sigmoid(cls_pred)
neg_cost = -(1 - cls_pred + self.eps).log() * (
1 - self.alpha) * cls_pred.pow(self.gamma)
pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
1 - cls_pred).pow(self.gamma)
cls_cost = paddle.index_select(
pos_cost, gt_labels, axis=1) - paddle.index_select(
neg_cost, gt_labels, axis=1)
return cls_cost * self.weight
def _mask_focal_loss_cost(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classfication logits
in shape (num_query, d1, ..., dn), dtype=paddle.float32.
gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
dtype=paddle.long. Labels should be binary.
Returns:
Tensor: Focal cost matrix with weight in shape\
(num_query, num_gt).
"""
cls_pred = cls_pred.flatten(1)
gt_labels = gt_labels.flatten(1).float()
n = cls_pred.shape[1]
cls_pred = F.sigmoid(cls_pred)
neg_cost = -(1 - cls_pred + self.eps).log() * (
1 - self.alpha) * cls_pred.pow(self.gamma)
pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
1 - cls_pred).pow(self.gamma)
cls_cost = paddle.einsum('nc,mc->nm', pos_cost, gt_labels) + \
paddle.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
return cls_cost / n * self.weight
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classfication logits.
gt_labels (Tensor)): Labels.
Returns:
Tensor: Focal cost matrix with weight in shape\
(num_query, num_gt).
"""
if self.binary_input:
return self._mask_focal_loss_cost(cls_pred, gt_labels)
else:
return self._focal_loss_cost(cls_pred, gt_labels)
...@@ -285,36 +285,6 @@ class BottleNeck(nn.Layer): ...@@ -285,36 +285,6 @@ class BottleNeck(nn.Layer):
# ResNeXt # ResNeXt
width = int(ch_out * (base_width / 64.)) * groups width = int(ch_out * (base_width / 64.)) * groups
self.shortcut = shortcut
if not shortcut:
if variant == 'd' and stride == 2:
self.short = nn.Sequential()
self.short.add_sublayer(
'pool',
nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True))
self.short.add_sublayer(
'conv',
ConvNormLayer(
ch_in=ch_in,
ch_out=ch_out * self.expansion,
filter_size=1,
stride=1,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr))
else:
self.short = ConvNormLayer(
ch_in=ch_in,
ch_out=ch_out * self.expansion,
filter_size=1,
stride=stride,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr)
self.branch2a = ConvNormLayer( self.branch2a = ConvNormLayer(
ch_in=ch_in, ch_in=ch_in,
ch_out=width, ch_out=width,
...@@ -351,6 +321,36 @@ class BottleNeck(nn.Layer): ...@@ -351,6 +321,36 @@ class BottleNeck(nn.Layer):
freeze_norm=freeze_norm, freeze_norm=freeze_norm,
lr=lr) lr=lr)
self.shortcut = shortcut
if not shortcut:
if variant == 'd' and stride == 2:
self.short = nn.Sequential()
self.short.add_sublayer(
'pool',
nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True))
self.short.add_sublayer(
'conv',
ConvNormLayer(
ch_in=ch_in,
ch_out=ch_out * self.expansion,
filter_size=1,
stride=1,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr))
else:
self.short = ConvNormLayer(
ch_in=ch_in,
ch_out=ch_out * self.expansion,
filter_size=1,
stride=stride,
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
lr=lr)
self.std_senet = std_senet self.std_senet = std_senet
if self.std_senet: if self.std_senet:
self.se = SELayer(ch_out * self.expansion) self.se = SELayer(ch_out * self.expansion)
......
...@@ -284,9 +284,9 @@ class RelativePositionBias(nn.Layer): ...@@ -284,9 +284,9 @@ class RelativePositionBias(nn.Layer):
def forward(self): def forward(self):
relative_position_bias = \ relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH self.window_size[0] * self.window_size[1] + 1, -1]) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.transpose((2, 0, 1)) # nH, Wh*Ww, Wh*Ww return relative_position_bias.transpose((2, 0, 1)) # nH, Wh*Ww, Wh*Ww
......
...@@ -67,3 +67,4 @@ from .yolof_head import * ...@@ -67,3 +67,4 @@ from .yolof_head import *
from .ppyoloe_contrast_head import * from .ppyoloe_contrast_head import *
from .centertrack_head import * from .centertrack_head import *
from .sparse_roi_head import * from .sparse_roi_head import *
from .petr_head import *
此差异已折叠。
...@@ -1135,7 +1135,7 @@ def _convert_attention_mask(attn_mask, dtype): ...@@ -1135,7 +1135,7 @@ def _convert_attention_mask(attn_mask, dtype):
""" """
return nn.layer.transformer._convert_attention_mask(attn_mask, dtype) return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)
@register
class MultiHeadAttention(nn.Layer): class MultiHeadAttention(nn.Layer):
""" """
Attention mapps queries and a set of key-value pairs to outputs, and Attention mapps queries and a set of key-value pairs to outputs, and
......
...@@ -21,7 +21,7 @@ import paddle.nn.functional as F ...@@ -21,7 +21,7 @@ import paddle.nn.functional as F
import paddle.nn as nn import paddle.nn as nn
from ppdet.core.workspace import register from ppdet.core.workspace import register
__all__ = ['FocalLoss'] __all__ = ['FocalLoss', 'Weighted_FocalLoss']
@register @register
class FocalLoss(nn.Layer): class FocalLoss(nn.Layer):
...@@ -59,3 +59,80 @@ class FocalLoss(nn.Layer): ...@@ -59,3 +59,80 @@ class FocalLoss(nn.Layer):
pred, target, alpha=self.alpha, gamma=self.gamma, pred, target, alpha=self.alpha, gamma=self.gamma,
reduction=reduction) reduction=reduction)
return loss * self.loss_weight return loss * self.loss_weight
@register
class Weighted_FocalLoss(FocalLoss):
"""A wrapper around paddle.nn.functional.sigmoid_focal_loss.
Args:
use_sigmoid (bool): currently only support use_sigmoid=True
alpha (float): parameter alpha in Focal Loss
gamma (float): parameter gamma in Focal Loss
loss_weight (float): final loss will be multiplied by this
"""
def __init__(self,
use_sigmoid=True,
alpha=0.25,
gamma=2.0,
loss_weight=1.0,
reduction="mean"):
super(FocalLoss, self).__init__()
assert use_sigmoid == True, \
'Focal Loss only supports sigmoid at the moment'
self.use_sigmoid = use_sigmoid
self.alpha = alpha
self.gamma = gamma
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, avg_factor=None, reduction_override=None):
"""forward function.
Args:
pred (Tensor): logits of class prediction, of shape (N, num_classes)
target (Tensor): target class label, of shape (N, )
reduction (str): the way to reduce loss, one of (none, sum, mean)
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
num_classes = pred.shape[1]
target = F.one_hot(target, num_classes + 1).astype(pred.dtype)
target = target[:, :-1].detach()
loss = F.sigmoid_focal_loss(
pred, target, alpha=self.alpha, gamma=self.gamma,
reduction='none')
if weight is not None:
if weight.shape != loss.shape:
if weight.shape[0] == loss.shape[0]:
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.reshape((-1, 1))
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.reshape((loss.shape[0], -1))
assert weight.ndim == loss.ndim
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = 1e-10
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss * self.loss_weight
...@@ -18,12 +18,13 @@ from __future__ import print_function ...@@ -18,12 +18,13 @@ from __future__ import print_function
from itertools import cycle, islice from itertools import cycle, islice
from collections import abc from collections import abc
import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
__all__ = ['HrHRNetLoss', 'KeyPointMSELoss'] __all__ = ['HrHRNetLoss', 'KeyPointMSELoss', 'OKSLoss', 'CenterFocalLoss', 'L1Loss']
@register @register
...@@ -226,3 +227,406 @@ def recursive_sum(inputs): ...@@ -226,3 +227,406 @@ def recursive_sum(inputs):
if isinstance(inputs, abc.Sequence): if isinstance(inputs, abc.Sequence):
return sum([recursive_sum(x) for x in inputs]) return sum([recursive_sum(x) for x in inputs])
return inputs return inputs
def oks_overlaps(kpt_preds, kpt_gts, kpt_valids, kpt_areas, sigmas):
if not kpt_gts.astype('bool').any():
return kpt_preds.sum()*0
sigmas = paddle.to_tensor(sigmas, dtype=kpt_preds.dtype)
variances = (sigmas * 2)**2
assert kpt_preds.shape[0] == kpt_gts.shape[0]
kpt_preds = kpt_preds.reshape((-1, kpt_preds.shape[-1] // 2, 2))
kpt_gts = kpt_gts.reshape((-1, kpt_gts.shape[-1] // 2, 2))
squared_distance = (kpt_preds[:, :, 0] - kpt_gts[:, :, 0]) ** 2 + \
(kpt_preds[:, :, 1] - kpt_gts[:, :, 1]) ** 2
assert (kpt_valids.sum(-1) > 0).all()
squared_distance0 = squared_distance / (
kpt_areas[:, None] * variances[None, :] * 2)
squared_distance1 = paddle.exp(-squared_distance0)
squared_distance1 = squared_distance1 * kpt_valids
oks = squared_distance1.sum(axis=1) / kpt_valids.sum(axis=1)
return oks
def oks_loss(pred,
target,
weight,
valid=None,
area=None,
linear=False,
sigmas=None,
eps=1e-6,
avg_factor=None,
reduction=None):
"""Oks loss.
Computing the oks loss between a set of predicted poses and target poses.
The loss is calculated as negative log of oks.
Args:
pred (Tensor): Predicted poses of format (x1, y1, x2, y2, ...),
shape (n, K*2).
target (Tensor): Corresponding gt poses, shape (n, K*2).
linear (bool, optional): If True, use linear scale of loss instead of
log scale. Default: False.
eps (float): Eps to avoid log(0).
Returns:
Tensor: Loss tensor.
"""
oks = oks_overlaps(pred, target, valid, area, sigmas).clip(min=eps)
if linear:
loss = 1 - oks
else:
loss = -oks.log()
if weight is not None:
if weight.shape != loss.shape:
if weight.shape[0] == loss.shape[0]:
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.reshape((-1, 1))
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.reshape((loss.shape[0], -1))
assert weight.ndim == loss.ndim
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = 1e-10
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
@register
@serializable
class OKSLoss(nn.Layer):
"""OKSLoss.
Computing the oks loss between a set of predicted poses and target poses.
Args:
linear (bool): If True, use linear scale of loss instead of log scale.
Default: False.
eps (float): Eps to avoid log(0).
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Weight of loss.
"""
def __init__(self,
linear=False,
num_keypoints=17,
eps=1e-6,
reduction='mean',
loss_weight=1.0):
super(OKSLoss, self).__init__()
self.linear = linear
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
if num_keypoints == 17:
self.sigmas = np.array([
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
1.07, .87, .87, .89, .89
], dtype=np.float32) / 10.0
elif num_keypoints == 14:
self.sigmas = np.array([
.79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89,
.79, .79
]) / 10.0
else:
raise ValueError(f'Unsupported keypoints number {num_keypoints}')
def forward(self,
pred,
target,
valid,
area,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
"""Forward function.
Args:
pred (Tensor): The prediction.
target (Tensor): The learning target of the prediction.
valid (Tensor): The visible flag of the target pose.
area (Tensor): The area of the target pose.
weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None. Options are "none", "mean" and "sum".
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if (weight is not None) and (not paddle.any(weight > 0)) and (
reduction != 'none'):
if pred.dim() == weight.dim() + 1:
weight = weight.unsqueeze(1)
return (pred * weight).sum() # 0
if weight is not None and weight.dim() > 1:
# TODO: remove this in the future
# reduce the weight of shape (n, 4) to (n,) to match the
# iou_loss of shape (n,)
assert weight.shape == pred.shape
weight = weight.mean(-1)
loss = self.loss_weight * oks_loss(
pred,
target,
weight,
valid=valid,
area=area,
linear=self.linear,
sigmas=self.sigmas,
eps=self.eps,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss
def center_focal_loss(pred, gt, weight=None, mask=None, avg_factor=None, reduction=None):
"""Modified focal loss. Exactly the same as CornerNet.
Runs faster and costs a little bit more memory.
Args:
pred (Tensor): The prediction with shape [bs, c, h, w].
gt (Tensor): The learning target of the prediction in gaussian
distribution, with shape [bs, c, h, w].
mask (Tensor): The valid mask. Defaults to None.
"""
if not gt.astype('bool').any():
return pred.sum()*0
pos_inds = gt.equal(1).astype('float32')
if mask is None:
neg_inds = gt.less_than(paddle.to_tensor([1], dtype='float32')).astype('float32')
else:
neg_inds = gt.less_than(paddle.to_tensor([1], dtype='float32')).astype('float32') * mask.equal(0).astype('float32')
neg_weights = paddle.pow(1 - gt, 4)
loss = 0
pos_loss = paddle.log(pred) * paddle.pow(1 - pred, 2) * pos_inds
neg_loss = paddle.log(1 - pred) * paddle.pow(pred, 2) * neg_weights * \
neg_inds
num_pos = pos_inds.astype('float32').sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if num_pos == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
if weight is not None:
if weight.shape != loss.shape:
if weight.shape[0] == loss.shape[0]:
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.reshape((-1, 1))
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.reshape((loss.shape[0], -1))
assert weight.ndim == loss.ndim
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = 1e-10
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
@register
@serializable
class CenterFocalLoss(nn.Layer):
"""CenterFocalLoss is a variant of focal loss.
More details can be found in the `paper
<https://arxiv.org/abs/1808.01244>`_
Args:
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Loss weight of current loss.
"""
def __init__(self,
reduction='none',
loss_weight=1.0):
super(CenterFocalLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
mask=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (Tensor): The prediction.
target (Tensor): The learning target of the prediction in gaussian
distribution.
weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None.
mask (Tensor): The valid mask. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_reg = self.loss_weight * center_focal_loss(
pred,
target,
weight,
mask=mask,
reduction=reduction,
avg_factor=avg_factor)
return loss_reg
def l1_loss(pred, target, weight=None, reduction='mean', avg_factor=None):
"""L1 loss.
Args:
pred (Tensor): The prediction.
target (Tensor): The learning target of the prediction.
Returns:
Tensor: Calculated loss
"""
if not target.astype('bool').any():
return pred.sum() * 0
assert pred.shape == target.shape
loss = paddle.abs(pred - target)
if weight is not None:
if weight.shape != loss.shape:
if weight.shape[0] == loss.shape[0]:
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.reshape((-1, 1))
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.reshape((loss.shape[0], -1))
assert weight.ndim == loss.ndim
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = 1e-10
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
@register
@serializable
class L1Loss(nn.Layer):
"""L1 loss.
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
super(L1Loss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (Tensor): The prediction.
target (Tensor): The learning target of the prediction.
weight (Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_bbox
...@@ -36,3 +36,4 @@ from .es_pan import * ...@@ -36,3 +36,4 @@ from .es_pan import *
from .lc_pan import * from .lc_pan import *
from .custom_pan import * from .custom_pan import *
from .dilated_encoder import * from .dilated_encoder import *
from .channel_mapper import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is base on mmdet: git@github.com:open-mmlab/mmdetection.git
"""
import paddle.nn as nn
from ppdet.core.workspace import register, serializable
from ..backbones.hrnet import ConvNormLayer
from ..shape_spec import ShapeSpec
from ..initializer import xavier_uniform_, constant_
__all__ = ['ChannelMapper']
@register
@serializable
class ChannelMapper(nn.Layer):
"""Channel Mapper to reduce/increase channels of backbone features.
This is used to reduce/increase channels of backbone features.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
kernel_size (int, optional): kernel_size for reducing channels (used
at each scale). Default: 3.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
act_cfg (dict, optional): Config dict for activation layer in
ConvModule. Default: dict(type='ReLU').
num_outs (int, optional): Number of output feature maps. There
would be extra_convs when num_outs larger than the length
of in_channels.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
norm_type="gn",
norm_groups=32,
act='relu',
num_outs=None,
init_cfg=dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super(ChannelMapper, self).__init__()
assert isinstance(in_channels, list)
self.extra_convs = None
if num_outs is None:
num_outs = len(in_channels)
self.convs = nn.LayerList()
for in_channel in in_channels:
self.convs.append(
ConvNormLayer(
ch_in=in_channel,
ch_out=out_channels,
filter_size=kernel_size,
norm_type='gn',
norm_groups=32,
act=act))
if num_outs > len(in_channels):
self.extra_convs = nn.LayerList()
for i in range(len(in_channels), num_outs):
if i == len(in_channels):
in_channel = in_channels[-1]
else:
in_channel = out_channels
self.extra_convs.append(
ConvNormLayer(
ch_in=in_channel,
ch_out=out_channels,
filter_size=3,
stride=2,
norm_type='gn',
norm_groups=32,
act=act))
self.init_weights()
def forward(self, inputs):
"""Forward function."""
assert len(inputs) == len(self.convs)
outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
if self.extra_convs:
for i in range(len(self.extra_convs)):
if i == 0:
outs.append(self.extra_convs[0](inputs[-1]))
else:
outs.append(self.extra_convs[i](outs[-1]))
return tuple(outs)
@property
def out_shape(self):
return [
ShapeSpec(
channels=self.out_channel, stride=1. / s)
for s in self.spatial_scales
]
def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.rank() > 1:
xavier_uniform_(p)
if hasattr(p, 'bias') and p.bias is not None:
constant_(p.bais)
...@@ -25,3 +25,4 @@ from .matchers import * ...@@ -25,3 +25,4 @@ from .matchers import *
from .position_encoding import * from .position_encoding import *
from .deformable_transformer import * from .deformable_transformer import *
from .dino_transformer import * from .dino_transformer import *
from .petr_transformer import *
此差异已折叠。
...@@ -238,7 +238,7 @@ def draw_pose(image, ...@@ -238,7 +238,7 @@ def draw_pose(image,
'for example: `pip install matplotlib`.') 'for example: `pip install matplotlib`.')
raise e raise e
skeletons = np.array([item['keypoints'] for item in results]) skeletons = np.array([item['keypoints'] for item in results]).reshape((-1, 51))
kpt_nums = 17 kpt_nums = 17
if len(skeletons) > 0: if len(skeletons) > 0:
kpt_nums = int(skeletons.shape[1] / 3) kpt_nums = int(skeletons.shape[1] / 3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册