未验证 提交 78a6629b 编写于 作者: P pk_hk 提交者: GitHub

[smalldet] add tal_cr and custompan with transfomer, add reg_max pred (#7253)

* [smdet] add tal_cr and custompan with transfomer, add box_distribution to pred reg_max, test=document_fix

* add tal_cr
上级 229350a2
......@@ -13,11 +13,14 @@ PaddleDetection团队提供了针对VisDrone-DET小目标数航拍场景的基
|:---------|:------:|:------:| :----: | :------:| :------: | :------:| :----: | :------:|
|PP-YOLOE-s| 23.5 | 39.9 | 19.4 | 33.6 | 23.68 | 40.66 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_s_80e_visdrone.yml) |
|PP-YOLOE-P2-Alpha-s| 24.4 | 41.6 | 20.1 | 34.7 | 24.55 | 42.19 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_p2_alpha_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_s_p2_alpha_80e_visdrone.yml) |
|PP_YOLOE_plus_new_s| 25.1 | 42.8 | 20.7 | 36.2 | 25.16 | 43.86 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_new_crn_s_80e_visdrone.pdparams) | [配置文件](./ppyoloe_plus_new_crn_s_80e_visdrone.yml) |
|PP-YOLOE-l| 29.2 | 47.3 | 23.5 | 39.1 | 28.00 | 46.20 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_l_80e_visdrone.yml) |
|PP-YOLOE-P2-Alpha-l| 30.1 | 48.9 | 24.3 | 40.8 | 28.47 | 48.16 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_p2_alpha_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_l_p2_alpha_80e_visdrone.yml) |
|PP_YOLOE_plus_new_l| 31.9 | 52.1 | 25.6 | 43.5 | 30.25 | 51.18 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_new_crn_l_80e_visdrone.pdparams) | [配置文件](./ppyoloe_plus_new_crn_l_80e_visdrone.yml) |
|PP-YOLOE-Alpha-largesize-l| 41.9 | 65.0 | 32.3 | 53.0 | 37.13 | 61.15 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_alpha_largesize_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_l_alpha_largesize_80e_visdrone.yml) |
|PP-YOLOE-P2-Alpha-largesize-l| 41.3 | 64.5 | 32.4 | 53.1 | 37.49 | 51.54 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_p2_alpha_largesize_80e_visdrone.pdparams) | [配置文件](./ppyoloe_crn_l_p2_alpha_largesize_80e_visdrone.yml) |
|PP-YOLOE-plus-largesize-l | 43.3 | 66.7 | 33.5 | 54.7 | 38.24 | 62.76 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_largesize_80e_visdrone.pdparams) | [配置文件](./ppyoloe_plus_crn_l_largesize_80e_visdrone.yml) |
|PP-YOLOE-plus_new-largesize_l | 42.7 | 65.9 | 33.6 | 55.1 | 38.4 | 63.07 | [下载链接](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_new_crn_l_largesize_80e_visdrone.pdparams) | [配置文件](./ppyoloe_plus_new_crn_l_largesize_80e_visdrone.yml) |
## 原图评估和拼图评估对比:
......
_BASE_: [
'../datasets/visdrone_detection.yml',
'../runtime.yml',
'../ppyoloe/_base_/optimizer_80e.yml',
'../ppyoloe/_base_/ppyoloe_plus_crn.yml',
'../ppyoloe/_base_/ppyoloe_plus_reader.yml'
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_plus_new_crn_l_80e_visdrone/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_l_80e_coco.pdparams
depth_mult: 1.0
width_mult: 1.0
TrainReader:
batch_size: 8
EvalReader:
batch_size: 1
TestReader:
batch_size: 1
fuse_normalize: True
epoch: 80
LearningRate:
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: 96
- !LinearWarmup
start_factor: 0.
epochs: 1
CustomCSPPAN:
num_layers: 4
use_trans: True
PPYOLOEHead:
reg_range: [-2,8]
static_assigner_epoch: -1
static_assigner:
name: ATSSAssigner
topk: 9
sm_use: True
assigner:
name: TaskAlignedAssigner_CR
center_radius: 1
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 10000
keep_top_k: 500
score_threshold: 0.01
nms_threshold: 0.6
_BASE_: [
'ppyoloe_plus_new_crn_l_80e_visdrone.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_plus_new_large_crn_l_80e_visdrone/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams
PPYOLOEHead:
reg_range: [-2,20]
static_assigner_epoch: -1
LearningRate:
base_lr: 0.00125
worker_num: 2
eval_height: &eval_height 1920
eval_width: &eval_width 1920
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792, 1856, 1920], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
- PadGT: {}
batch_size: 1
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
TestReader:
inputs_def:
image_shape: [3, *eval_height, *eval_width]
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
fuse_normalize: True
_BASE_: [
'../datasets/visdrone_detection.yml',
'../runtime.yml',
'../ppyoloe/_base_/optimizer_80e.yml',
'../ppyoloe/_base_/ppyoloe_plus_crn.yml',
'../ppyoloe/_base_/ppyoloe_plus_reader.yml'
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_plus_new_crn_s_80e_visdrone/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_s_80e_coco.pdparams
depth_mult: 0.33
width_mult: 0.50
TrainReader:
batch_size: 8
EvalReader:
batch_size: 1
TestReader:
batch_size: 1
fuse_normalize: True
epoch: 80
LearningRate:
base_lr: 0.01
schedulers:
- !CosineDecay
max_epochs: 96
- !LinearWarmup
start_factor: 0.
epochs: 1
CustomCSPPAN:
num_layers: 4
use_trans: True
PPYOLOEHead:
reg_range: [-2,8]
static_assigner_epoch: -1
static_assigner:
name: ATSSAssigner
topk: 9
sm_use: True
assigner:
name: TaskAlignedAssigner_CR
center_radius: 1
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 10000
keep_top_k: 500
score_threshold: 0.01
nms_threshold: 0.6
......@@ -19,6 +19,7 @@ from . import simota_assigner
from . import max_iou_assigner
from . import fcosr_assigner
from . import rotated_task_aligned_assigner
from . import task_aligned_assigner_cr
from .utils import *
from .task_aligned_assigner import *
......@@ -27,3 +28,4 @@ from .simota_assigner import *
from .max_iou_assigner import *
from .fcosr_assigner import *
from .rotated_task_aligned_assigner import *
from .task_aligned_assigner_cr import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -41,12 +41,14 @@ class ATSSAssigner(nn.Layer):
topk=9,
num_classes=80,
force_gt_matching=False,
eps=1e-9):
eps=1e-9,
sm_use=False):
super(ATSSAssigner, self).__init__()
self.topk = topk
self.num_classes = num_classes
self.force_gt_matching = force_gt_matching
self.eps = eps
self.sm_use = sm_use
def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
pad_gt_mask):
......@@ -154,6 +156,10 @@ class ATSSAssigner(nn.Layer):
paddle.zeros_like(is_in_topk))
# 6. check the positive sample's center in gt, [B, n, L]
if self.sm_use:
is_in_gts = check_points_inside_bboxes(
anchor_centers, gt_bboxes, sm_use=True)
else:
is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
# select positive sample, [B, n, L]
......@@ -165,6 +171,9 @@ class ATSSAssigner(nn.Layer):
if mask_positive_sum.max() > 1:
mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
[1, num_max_boxes, 1])
if self.sm_use:
is_max_iou = compute_max_iou_anchor(ious * mask_positive)
else:
is_max_iou = compute_max_iou_anchor(ious)
mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
mask_positive)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..bbox_utils import batch_iou_similarity
from .utils import (gather_topk_anchors, check_points_inside_bboxes,
compute_max_iou_anchor)
__all__ = ['TaskAlignedAssigner_CR']
@register
class TaskAlignedAssigner_CR(nn.Layer):
"""TOOD: Task-aligned One-stage Object Detection with Center R
"""
def __init__(self,
topk=13,
alpha=1.0,
beta=6.0,
center_radius=None,
eps=1e-9):
super(TaskAlignedAssigner_CR, self).__init__()
self.topk = topk
self.alpha = alpha
self.beta = beta
self.center_radius = center_radius
self.eps = eps
@paddle.no_grad()
def forward(self,
pred_scores,
pred_bboxes,
anchor_points,
stride_tensor,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index,
gt_scores=None):
r"""This code is based on
https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/task_aligned_assigner.py
The assignment is done in following steps
1. compute alignment metric between all bbox (bbox of all pyramid levels) and gt
2. select top-k bbox as candidates for each gt
3. limit the positive sample's center in gt (because the anchor-free detector
only can predict positive distance)
4. if an anchor box is assigned to multiple gts, the one with the
highest iou will be selected.
Args:
pred_scores (Tensor, float32): predicted class probability, shape(B, L, C)
pred_bboxes (Tensor, float32): predicted bounding boxes, shape(B, L, 4)
anchor_points (Tensor, float32): pre-defined anchors, shape(L, 2), "cxcy" format
stride_tensor (Tensor, float32): stride of feature map, shape(L, 1)
gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
bg_index (int): background index
gt_scores (Tensor|None, float32) Score of gt_bboxes, shape(B, n, 1)
Returns:
assigned_labels (Tensor): (B, L)
assigned_bboxes (Tensor): (B, L, 4)
assigned_scores (Tensor): (B, L, C)
"""
assert pred_scores.ndim == pred_bboxes.ndim
assert gt_labels.ndim == gt_bboxes.ndim and \
gt_bboxes.ndim == 3
batch_size, num_anchors, num_classes = pred_scores.shape
_, num_max_boxes, _ = gt_bboxes.shape
# negative batch
if num_max_boxes == 0:
assigned_labels = paddle.full(
[batch_size, num_anchors], bg_index, dtype='int32')
assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
assigned_scores = paddle.zeros(
[batch_size, num_anchors, num_classes])
return assigned_labels, assigned_bboxes, assigned_scores
# compute iou between gt and pred bbox, [B, n, L]
ious = batch_iou_similarity(gt_bboxes, pred_bboxes)
# gather pred bboxes class score
pred_scores = pred_scores.transpose([0, 2, 1])
batch_ind = paddle.arange(
end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
gt_labels_ind = paddle.stack(
[batch_ind.tile([1, num_max_boxes]), gt_labels.squeeze(-1)],
axis=-1)
bbox_cls_scores = paddle.gather_nd(pred_scores, gt_labels_ind)
# compute alignment metrics, [B, n, L]
alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(
self.beta) * pad_gt_mask
# select positive sample, [B, n, L]
if self.center_radius is None:
# check the positive sample's center in gt, [B, n, L]
is_in_gts = check_points_inside_bboxes(
anchor_points, gt_bboxes, sm_use=True)
# select topk largest alignment metrics pred bbox as candidates
# for each gt, [B, n, L]
mask_positive = gather_topk_anchors(
alignment_metrics, self.topk, topk_mask=pad_gt_mask) * is_in_gts
else:
is_in_gts, is_in_center = check_points_inside_bboxes(
anchor_points,
gt_bboxes,
stride_tensor * self.center_radius,
sm_use=True)
is_in_gts *= pad_gt_mask
is_in_center *= pad_gt_mask
candidate_metrics = paddle.where(
is_in_gts.sum(-1, keepdim=True) == 0,
alignment_metrics + is_in_center,
alignment_metrics)
mask_positive = gather_topk_anchors(
candidate_metrics, self.topk,
topk_mask=pad_gt_mask) * paddle.cast((is_in_center > 0) |
(is_in_gts > 0), 'float32')
# if an anchor box is assigned to multiple gts,
# the one with the highest iou will be selected, [B, n, L]
mask_positive_sum = mask_positive.sum(axis=-2)
if mask_positive_sum.max() > 1:
mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
[1, num_max_boxes, 1])
is_max_iou = compute_max_iou_anchor(ious * mask_positive)
mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
mask_positive)
mask_positive_sum = mask_positive.sum(axis=-2)
assigned_gt_index = mask_positive.argmax(axis=-2)
# assigned target
assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
assigned_labels = paddle.gather(
gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
assigned_labels = paddle.where(
mask_positive_sum > 0, assigned_labels,
paddle.full_like(assigned_labels, bg_index))
assigned_bboxes = paddle.gather(
gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
assigned_scores = F.one_hot(assigned_labels, num_classes + 1)
ind = list(range(num_classes + 1))
ind.remove(bg_index)
assigned_scores = paddle.index_select(
assigned_scores, paddle.to_tensor(ind), axis=-1)
# rescale alignment metrics
alignment_metrics *= mask_positive
max_metrics_per_instance = alignment_metrics.max(axis=-1, keepdim=True)
max_ious_per_instance = (ious * mask_positive).max(axis=-1,
keepdim=True)
alignment_metrics = alignment_metrics / (
max_metrics_per_instance + self.eps) * max_ious_per_instance
alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
assigned_scores = assigned_scores * alignment_metrics
return assigned_labels, assigned_bboxes, assigned_scores
......@@ -108,7 +108,8 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9):
def check_points_inside_bboxes(points,
bboxes,
center_radius_tensor=None,
eps=1e-9):
eps=1e-9,
sm_use=False):
r"""
Args:
points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors
......@@ -139,6 +140,10 @@ def check_points_inside_bboxes(points,
b = (cy + center_radius_tensor) - y
delta_ltrb_c = paddle.concat([l, t, r, b], axis=-1)
is_in_center = (delta_ltrb_c.min(axis=-1) > eps)
if sm_use:
return is_in_bboxes.astype(bboxes.dtype), is_in_center.astype(
bboxes.dtype)
else:
return (paddle.logical_and(is_in_bboxes, is_in_center),
paddle.logical_or(is_in_bboxes, is_in_center))
......
......@@ -60,6 +60,7 @@ class PPYOLOEHead(nn.Layer):
grid_cell_scale=5.0,
grid_cell_offset=0.5,
reg_max=16,
reg_range=False,
static_assigner_epoch=4,
use_varifocal_loss=True,
static_assigner='ATSSAssigner',
......@@ -82,7 +83,12 @@ class PPYOLOEHead(nn.Layer):
self.fpn_strides = fpn_strides
self.grid_cell_scale = grid_cell_scale
self.grid_cell_offset = grid_cell_offset
self.reg_max = reg_max
if reg_range:
self.sm_use = True
self.reg_range = reg_range
else:
self.reg_range = (0, reg_max + 1)
self.reg_channels = self.reg_range[1] - self.reg_range[0]
self.iou_loss = GIoULoss()
self.loss_weight = loss_weight
self.use_varifocal_loss = use_varifocal_loss
......@@ -116,9 +122,9 @@ class PPYOLOEHead(nn.Layer):
in_c, self.num_classes, 3, padding=1))
self.pred_reg.append(
nn.Conv2D(
in_c, 4 * (self.reg_max + 1), 3, padding=1))
in_c, 4 * self.reg_channels, 3, padding=1))
# projection conv
self.proj_conv = nn.Conv2D(self.reg_max + 1, 1, 1, bias_attr=False)
self.proj_conv = nn.Conv2D(self.reg_channels, 1, 1, bias_attr=False)
self.proj_conv.skip_quant = True
self._init_weights()
......@@ -134,8 +140,9 @@ class PPYOLOEHead(nn.Layer):
constant_(reg_.weight)
constant_(reg_.bias, 1.0)
proj = paddle.linspace(0, self.reg_max, self.reg_max + 1).reshape(
[1, self.reg_max + 1, 1, 1])
proj = paddle.linspace(self.reg_range[0], self.reg_range[1] - 1,
self.reg_channels).reshape(
[1, self.reg_channels, 1, 1])
self.proj_conv.weight.set_value(proj)
self.proj_conv.weight.stop_gradient = True
if self.eval_size:
......@@ -202,8 +209,8 @@ class PPYOLOEHead(nn.Layer):
cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
feat)
reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose(
[0, 2, 3, 1])
reg_dist = reg_dist.reshape(
[-1, 4, self.reg_channels, l]).transpose([0, 2, 3, 1])
if self.use_shared_conv:
reg_dist = self.proj_conv(F.softmax(
reg_dist, axis=1)).squeeze(1)
......@@ -251,7 +258,7 @@ class PPYOLOEHead(nn.Layer):
def _bbox_decode(self, anchor_points, pred_dist):
_, l, _ = get_static_shape(pred_dist)
pred_dist = F.softmax(pred_dist.reshape([-1, l, 4, self.reg_max + 1]))
pred_dist = F.softmax(pred_dist.reshape([-1, l, 4, self.reg_channels]))
pred_dist = self.proj_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1)
return batch_distance2bbox(anchor_points, pred_dist)
......@@ -259,17 +266,20 @@ class PPYOLOEHead(nn.Layer):
x1y1, x2y2 = paddle.split(bbox, 2, -1)
lt = points - x1y1
rb = x2y2 - points
return paddle.concat([lt, rb], -1).clip(0, self.reg_max - 0.01)
return paddle.concat([lt, rb], -1).clip(self.reg_range[0],
self.reg_range[1] - 1 - 0.01)
def _df_loss(self, pred_dist, target):
target_left = paddle.cast(target, 'int64')
def _df_loss(self, pred_dist, target, lower_bound=0):
target_left = paddle.cast(target.floor(), 'int64')
target_right = target_left + 1
weight_left = target_right.astype('float32') - target
weight_right = 1 - weight_left
loss_left = F.cross_entropy(
pred_dist, target_left, reduction='none') * weight_left
pred_dist, target_left - lower_bound,
reduction='none') * weight_left
loss_right = F.cross_entropy(
pred_dist, target_right, reduction='none') * weight_right
pred_dist, target_right - lower_bound,
reduction='none') * weight_right
return (loss_left + loss_right).mean(-1, keepdim=True)
def _bbox_loss(self, pred_dist, pred_bboxes, anchor_points, assigned_labels,
......@@ -295,14 +305,14 @@ class PPYOLOEHead(nn.Layer):
loss_iou = loss_iou.sum() / assigned_scores_sum
dist_mask = mask_positive.unsqueeze(-1).tile(
[1, 1, (self.reg_max + 1) * 4])
[1, 1, self.reg_channels * 4])
pred_dist_pos = paddle.masked_select(
pred_dist, dist_mask).reshape([-1, 4, self.reg_max + 1])
pred_dist, dist_mask).reshape([-1, 4, self.reg_channels])
assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes)
assigned_ltrb_pos = paddle.masked_select(
assigned_ltrb, bbox_mask).reshape([-1, 4])
loss_dfl = self._df_loss(pred_dist_pos,
assigned_ltrb_pos) * bbox_weight
loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos,
self.reg_range[0]) * bbox_weight
loss_dfl = loss_dfl.sum() / assigned_scores_sum
else:
loss_l1 = paddle.zeros([1])
......@@ -332,6 +342,18 @@ class PPYOLOEHead(nn.Layer):
bg_index=self.num_classes,
pred_bboxes=pred_bboxes.detach() * stride_tensor)
alpha_l = 0.25
else:
if self.sm_use:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
stride_tensor,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
else:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,19 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import copy
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import DropBlock
from ppdet.modeling.layers import DropBlock, MultiHeadAttention
from ppdet.modeling.ops import get_act_fn
from ..backbones.cspresnet import ConvBNLayer, BasicBlock
from ..shape_spec import ShapeSpec
from ..initializer import linear_init_
__all__ = ['CustomCSPPAN']
def _get_clones(module, N):
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
class SPP(nn.Layer):
def __init__(self,
ch_in,
......@@ -99,6 +106,81 @@ class CSPStage(nn.Layer):
return y
class TransformerEncoderLayer(nn.Layer):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
attn_dropout=None,
act_dropout=None,
normalize_before=False):
super(TransformerEncoderLayer, self).__init__()
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.self_attn = MultiHeadAttention(d_model, nhead, attn_dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(act_dropout, mode="upscale_in_train")
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
self._reset_parameters()
def _reset_parameters(self):
linear_init_(self.linear1)
linear_init_(self.linear2)
@staticmethod
def with_pos_embed(tensor, pos_embed):
return tensor if pos_embed is None else tensor + pos_embed
def forward(self, src, src_mask=None, pos_embed=None):
residual = src
if self.normalize_before:
src = self.norm1(src)
q = k = self.with_pos_embed(src, pos_embed)
src = self.self_attn(q, k, value=src, attn_mask=src_mask)
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return src
class TransformerEncoder(nn.Layer):
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src, src_mask=None, pos_embed=None):
output = src
for layer in self.layers:
output = layer(output, src_mask=src_mask, pos_embed=pos_embed)
if self.norm is not None:
output = self.norm(output)
return output
@register
@serializable
class CustomCSPPAN(nn.Layer):
......@@ -121,7 +203,16 @@ class CustomCSPPAN(nn.Layer):
width_mult=1.0,
depth_mult=1.0,
use_alpha=False,
trt=False):
trt=False,
dim_feedforward=2048,
dropout=0.1,
activation='gelu',
nhead=4,
num_layers=4,
attn_dropout=None,
act_dropout=None,
normalize_before=False,
use_trans=False):
super(CustomCSPPAN, self).__init__()
out_channels = [max(round(c * width_mult), 1) for c in out_channels]
......@@ -132,7 +223,19 @@ class CustomCSPPAN(nn.Layer):
self.num_blocks = len(in_channels)
self.data_format = data_format
self._out_channels = out_channels
self.hidden_dim = in_channels[-1]
in_channels = in_channels[::-1]
self.nhead = nhead
self.num_layers = num_layers
self.use_trans = use_trans
if use_trans:
encoder_layer = TransformerEncoderLayer(
self.hidden_dim, nhead, dim_feedforward, dropout, activation,
attn_dropout, act_dropout, normalize_before)
encoder_norm = nn.LayerNorm(
self.hidden_dim) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, self.num_layers,
encoder_norm)
fpn_stages = []
fpn_routes = []
for i, (ch_in, ch_out) in enumerate(zip(in_channels, out_channels)):
......@@ -204,7 +307,45 @@ class CustomCSPPAN(nn.Layer):
self.pan_stages = nn.LayerList(pan_stages[::-1])
self.pan_routes = nn.LayerList(pan_routes[::-1])
def build_2d_sincos_position_embedding(
self,
w,
h,
embed_dim=1024,
temperature=10000., ):
grid_w = paddle.arange(int(w), dtype=paddle.float32)
grid_h = paddle.arange(int(h), dtype=paddle.float32)
grid_w, grid_h = paddle.meshgrid(grid_w, grid_h)
assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
pos_dim = embed_dim // 4
omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
omega = 1. / (temperature**omega)
out_w = grid_w.flatten()[..., None] @omega[None]
out_h = grid_h.flatten()[..., None] @omega[None]
pos_emb = paddle.concat(
[
paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h),
paddle.cos(out_h)
],
axis=1)[None, :, :]
return pos_emb
def forward(self, blocks, for_mot=False):
if self.use_trans:
last_feat = blocks[-1]
n, c, h, w = last_feat.shape
# flatten [B, C, H, W] to [B, HxW, C]
src_flatten = last_feat.flatten(2).transpose([0, 2, 1])
pos_embed = self.build_2d_sincos_position_embedding(
w=w, h=h, embed_dim=self.hidden_dim)
memory = self.encoder(src_flatten, pos_embed=pos_embed)
last_feat_encode = memory.transpose([0, 2, 1]).reshape([n, c, h, w])
blocks[-1] = last_feat_encode
blocks = blocks[::-1]
fpn_feats = []
......
......@@ -17,6 +17,7 @@ import json
import numpy as np
import argparse
from pycocotools.coco import COCO
from tqdm import tqdm
def median(data):
......@@ -44,7 +45,7 @@ def draw_distribution(width, height, out_path):
plt.show()
def get_ratio_infos(jsonfile, out_img):
def get_ratio_infos(jsonfile, out_img, eval_size, small_stride):
coco = COCO(annotation_file=jsonfile)
allannjson = json.load(open(jsonfile, 'r'))
be_im_id = allannjson['annotations'][0]['image_id']
......@@ -52,7 +53,7 @@ def get_ratio_infos(jsonfile, out_img):
be_im_h = []
ratio_w = []
ratio_h = []
for i, ann in enumerate(allannjson['annotations']):
for ann in tqdm(allannjson['annotations']):
if ann['iscrowd']:
continue
x0, y0, w, h = ann['bbox'][:]
......@@ -82,8 +83,23 @@ def get_ratio_infos(jsonfile, out_img):
ratio_h.append(dis_h)
mid_w = median(ratio_w)
mid_h = median(ratio_h)
reg_ratio = []
ratio_all = ratio_h + ratio_w
for r in ratio_all:
if r < 0.2:
reg_ratio.append(r)
elif r < 0.4:
reg_ratio.append(r/2)
else:
reg_ratio.append(r/4)
reg_ratio = sorted(reg_ratio)
max_ratio = reg_ratio[int(0.95*len(reg_ratio))]
reg_max = round(max_ratio*eval_size/small_stride)
ratio_w = [i * 1000 for i in ratio_w]
ratio_h = [i * 1000 for i in ratio_h]
print(f'suggested of reg_range[1] is {reg_max+1}' )
print(f'Median of ratio_w is {mid_w}')
print(f'Median of ratio_h is {mid_h}')
print('all_img with box: ', len(ratio_h))
......@@ -95,6 +111,10 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--json_path', type=str, default=None, help="Dataset json path.")
parser.add_argument(
'--eval_size', type=int, default=640, help="eval size.")
parser.add_argument(
'--small_stride', type=int, default=8, help="smallest stride.")
parser.add_argument(
'--out_img',
type=str,
......@@ -102,7 +122,7 @@ def main():
help="Name of distibution img.")
args = parser.parse_args()
get_ratio_infos(args.json_path, args.out_img)
get_ratio_infos(args.json_path, args.out_img, args.eval_size, args.small_stride)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册