提交 fca67b8e 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: qingqing01

add cascade_cls_aware models (#19)

* add softnms, nonlocal, resnet200_vd_backbone
* add CBNet
* update model zoo
上级 75d4c0f6
architecture: CascadeRCNNClsAware
train_feed: FasterRCNNTrainFeed
eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar
weights: output/cascade_rcnn_cls_aware_r101_vd_fpn_1x_softnms/model_final
metric: COCO
num_classes: 81
CascadeRCNNClsAware:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: CascadeBBoxHead
bbox_assigner: CascadeBBoxAssigner
ResNet:
norm_type: bn
depth: 101
feature_maps: [2, 3, 4, 5]
freeze_at: 2
variant: d
FPN:
min_level: 2
max_level: 6
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
min_level: 2
max_level: 6
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_positive_overlap: 0.7
rpn_negative_overlap: 0.3
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 1000
post_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
min_level: 2
max_level: 5
box_resolution: 14
sampling_ratio: 2
CascadeBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [10, 20, 30]
bg_thresh_lo: [0.0, 0.0, 0.0]
bg_thresh_hi: [0.5, 0.6, 0.7]
fg_thresh: [0.5, 0.6, 0.7]
fg_fraction: 0.25
class_aware: True
CascadeBBoxHead:
head: CascadeTwoFCHead
nms: MultiClassSoftNMS
CascadeTwoFCHead:
mlp_dim: 1024
MultiClassSoftNMS:
score_threshold: 0.01
keep_top_k: 300
softnms_sigma: 0.5
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
FasterRCNNTrainFeed:
batch_size: 2
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_train2017.json
image_dir: train2017
batch_transforms:
- !PadBatch
pad_to_stride: 32
drop_last: false
num_workers: 2
FasterRCNNEvalFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: False
- !NormalizeImage
is_channel_first: false
is_scale: True
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
- !ResizeImage
interp: 1
target_size:
- 800
max_size: 1333
use_cv2: true
- !Permute
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
FasterRCNNTestFeed:
batch_size: 1
dataset:
annotation: dataset/coco/annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: False
- !NormalizeImage
is_channel_first: false
is_scale: True
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
- !ResizeImage
interp: 1
target_size:
- 800
max_size: 1333
use_cv2: true
- !Permute
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
drop_last: false
num_workers: 2
\ No newline at end of file
architecture: CascadeRCNN
train_feed: FasterRCNNTrainFeed
eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed
max_iters: 460000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/CBResNet200_vd_pretrained.tar
weights: output/cascade_rcnn_cbr200_vd_fpn_dcnv2_nonlocal_softnms/model_final
metric: COCO
num_classes: 81
CascadeRCNN:
backbone: CBResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: CascadeBBoxHead
bbox_assigner: CascadeBBoxAssigner
CBResNet:
norm_type: bn
depth: 200
feature_maps: [2, 3, 4, 5]
freeze_at: 2
variant: d
dcn_v2_stages: [3, 4, 5]
nonlocal_stages: [4]
repeat_num: 2
FPN:
min_level: 2
max_level: 6
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
min_level: 2
max_level: 6
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_positive_overlap: 0.7
rpn_negative_overlap: 0.3
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 1000
post_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
min_level: 2
max_level: 5
box_resolution: 14
sampling_ratio: 2
CascadeBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [10, 20, 30]
bg_thresh_lo: [0.0, 0.0, 0.0]
bg_thresh_hi: [0.5, 0.6, 0.7]
fg_thresh: [0.5, 0.6, 0.7]
fg_fraction: 0.25
CascadeBBoxHead:
head: CascadeTwoFCHead
nms: MultiClassSoftNMS
CascadeTwoFCHead:
mlp_dim: 1024
MultiClassSoftNMS:
score_threshold: 0.01
keep_top_k: 300
softnms_sigma: 0.5
LearningRate:
base_lr: 0.005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [340000, 440000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
FasterRCNNTrainFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_train2017.json
image_dir: train2017
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: False
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: True
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
- !ResizeImage
interp: 1
target_size: [416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216, 1248, 1280, 1312, 1344, 1376, 1408]
max_size: 1600
use_cv2: true
- !Permute
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
drop_last: false
num_workers: 2
FasterRCNNEvalFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: False
- !NormalizeImage
is_channel_first: false
is_scale: True
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
- !ResizeImage
interp: 1
target_size:
- 1200
max_size: 2000
use_cv2: true
- !Permute
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
FasterRCNNTestFeed:
batch_size: 1
dataset:
annotation: dataset/coco/annotations/instances_val2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
drop_last: false
num_workers: 2
architecture: CascadeRCNNClsAware
train_feed: FasterRCNNTrainFeed
eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed
max_iters: 460000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet200_vd_pretrained.tar
weights: output/cascade_rcnn_cls_aware_r200_vd_fpn_dcnv2_nonlocal_softnms/model_final
metric: COCO
num_classes: 81
CascadeRCNNClsAware:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: CascadeBBoxHead
bbox_assigner: CascadeBBoxAssigner
ResNet:
norm_type: bn
depth: 200
feature_maps: [2, 3, 4, 5]
freeze_at: 2
variant: d
dcn_v2_stages: [3, 4, 5]
nonlocal_stages: [4]
FPN:
min_level: 2
max_level: 6
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
min_level: 2
max_level: 6
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_positive_overlap: 0.7
rpn_negative_overlap: 0.3
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 1000
post_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
min_level: 2
max_level: 5
box_resolution: 14
sampling_ratio: 2
CascadeBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [10, 20, 30]
bg_thresh_lo: [0.0, 0.0, 0.0]
bg_thresh_hi: [0.5, 0.6, 0.7]
fg_thresh: [0.5, 0.6, 0.7]
fg_fraction: 0.25
class_aware: True
CascadeBBoxHead:
head: CascadeTwoFCHead
nms: MultiClassSoftNMS
CascadeTwoFCHead:
mlp_dim: 1024
MultiClassSoftNMS:
score_threshold: 0.01
keep_top_k: 300
softnms_sigma: 0.5
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [340000, 440000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
FasterRCNNTrainFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_train2017.json
image_dir: train2017
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: False
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: True
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
- !ResizeImage
interp: 1
target_size: [416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216, 1248, 1280, 1312, 1344, 1376, 1408]
max_size: 1800
use_cv2: true
- !Permute
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
drop_last: false
num_workers: 2
FasterRCNNEvalFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: False
- !NormalizeImage
is_channel_first: false
is_scale: True
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
- !ResizeImage
interp: 1
target_size:
- 1200
max_size: 2000
use_cv2: true
- !Permute
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
FasterRCNNTestFeed:
batch_size: 1
dataset:
annotation: dataset/coco/annotations/instances_val2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
drop_last: false
num_workers: 2
architecture: FasterRCNN
train_feed: FasterRCNNTrainFeed
eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/CBResNet101_vd_pretrained.tar
weights: output/faster_rcnn_cbr101_vd_dual_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: CBResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: BBoxAssigner
CBResNet:
norm_type: bn
norm_decay: 0.
depth: 101
feature_maps: [2, 3, 4, 5]
freeze_at: 2
variant: d
repeat_num: 2
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
BBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
FasterRCNNTrainFeed:
# batch size per device
batch_size: 2
dataset:
dataset_dir: dataset/coco
image_dir: train2017
annotation: annotations/instances_train2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
FasterRCNNEvalFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
FasterRCNNTestFeed:
batch_size: 1
dataset:
annotation: dataset/coco/annotations/instances_val2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
architecture: FasterRCNN
train_feed: FasterRCNNTrainFeed
eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/CBResNet50_vd_pretrained.tar
weights: output/faster_rcnn_cbr50_vd_dual_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: CBResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: BBoxAssigner
CBResNet:
norm_type: bn
norm_decay: 0.
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
variant: d
repeat_num: 2
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
BBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
FasterRCNNTrainFeed:
# batch size per device
batch_size: 2
dataset:
dataset_dir: dataset/coco
image_dir: train2017
annotation: annotations/instances_train2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
FasterRCNNEvalFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
FasterRCNNTestFeed:
batch_size: 1
dataset:
annotation: dataset/coco/annotations/instances_val2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
......@@ -49,6 +49,7 @@ The backbone models pretrained on ImageNet are available. All backbone models ar
| ResNet50-FPN | Cascade Mask | 1 | 1x | - | 41.3 | 35.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_mask_rcnn_r50_fpn_1x.tar) |
| ResNet50-vd-FPN | Faster | 2 | 2x | 21.847 | 38.9 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar) |
| ResNet50-vd-FPN | Mask | 1 | 2x | 15.825 | 39.8 | 35.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar) |
| CBResNet50-vd-FPN | Faster | 2 | 1x | - | 39.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_cbr50_vd_dual_fpn_1x.tar) |
| ResNet101 | Faster | 1 | 1x | 9.316 | 38.3 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_1x.tar) |
| ResNet101-FPN | Faster | 1 | 1x | 17.297 | 38.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_1x.tar) |
| ResNet101-FPN | Faster | 1 | 2x | 17.246 | 39.1 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar) |
......@@ -56,12 +57,14 @@ The backbone models pretrained on ImageNet are available. All backbone models ar
| ResNet101-vd-FPN | Faster | 1 | 1x | 17.011 | 40.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_1x.tar) |
| ResNet101-vd-FPN | Faster | 1 | 2x | 16.934 | 40.8 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar) |
| ResNet101-vd-FPN | Mask | 1 | 1x | 13.105 | 41.4 | 36.8 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar) |
| CBResNet101-vd-FPN | Faster | 2 | 1x | - | 42.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_cbr101_vd_dual_fpn_1x.tar) |
| ResNeXt101-vd-64x4d-FPN | Faster | 1 | 1x | 8.815 | 42.2 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_x101_vd_64x4d_fpn_1x.tar) |
| ResNeXt101-vd-64x4d-FPN | Faster | 1 | 2x | 8.809 | 41.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_x101_vd_64x4d_fpn_2x.tar) |
| ResNeXt101-vd-64x4d-FPN | Mask | 1 | 1x | 7.689 | 42.9 | 37.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_x101_vd_64x4d_fpn_1x.tar) |
| ResNeXt101-vd-64x4d-FPN | Mask | 1 | 2x | 7.859 | 42.6 | 37.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_x101_vd_64x4d_fpn_2x.tar) |
| SENet154-vd-FPN | Faster | 1 | 1.44x | 3.408 | 42.9 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_se154_vd_fpn_s1x.tar) |
| SENet154-vd-FPN | Mask | 1 | 1.44x | 3.233 | 44.0 | 38.7 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_se154_vd_fpn_s1x.tar) |
| ResNet101-vd-FPN | CascadeClsAware Faster | 2 | 1x | - | 44.7(softnms) | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_cls_aware_r101_vd_fpn_1x_softnms.tar) |
### Deformable ConvNets v2
......@@ -79,6 +82,10 @@ The backbone models pretrained on ImageNet are available. All backbone models ar
| ResNet101-vd-FPN | Cascade Faster | c3-c5 | 2 | 1x | - | 46.4 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_dcn_r101_vd_fpn_1x.tar) |
| ResNeXt101-vd-FPN | Cascade Faster | c3-c5 | 2 | 1x | - | 47.3 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_dcn_x101_vd_64x4d_fpn_1x.tar) |
| SENet154-vd-FPN | Cascade Mask | c3-c5 | 1 | 1.44x | - | 51.9 | 43.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_mask_rcnn_dcnv2_se154_vd_fpn_gn_s1x.tar) |
| ResNet200-vd-FPN-Nonlocal | CascadeClsAware Faster | c3-c5 | 1 | 2.5x | - | 51.7%(softnms) | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_cls_aware_r200_vd_fpn_dcnv2_nonlocal_softnms.tar) |
| CBResNet200-vd-FPN-Nonlocal | Cascade Faster | c3-c5 | 1 | 2.5x | - | 53.3%(softnms) | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_cbr200_vd_fpn_dcnv2_nonlocal_softnms.tar) |
#### Notes:
- Deformable ConvNets v2(dcn_v2) reference from [Deformable ConvNets v2](https://arxiv.org/abs/1811.11168).
......
......@@ -46,6 +46,7 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
| ResNet50-FPN | Cascade Mask | 1 | 1x | - | 41.3 | 35.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cascade_mask_rcnn_r50_fpn_1x.tar) |
| ResNet50-vd-FPN | Faster | 2 | 2x | 21.847 | 38.9 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_2x.tar) |
| ResNet50-vd-FPN | Mask | 1 | 2x | 15.825 | 39.8 | 35.4 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_2x.tar) |
| CBResNet50-vd-FPN | Faster | 2 | 1x | - | 39.7 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_cbr50_vd_dual_fpn_1x.tar) |
| ResNet101 | Faster | 1 | 1x | 9.316 | 38.3 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_1x.tar) |
| ResNet101-FPN | Faster | 1 | 1x | 17.297 | 38.7 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_1x.tar) |
| ResNet101-FPN | Faster | 1 | 2x | 17.246 | 39.1 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar) |
......@@ -59,6 +60,8 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
| ResNeXt101-vd-FPN | Mask | 1 | 2x | 7.859 | 42.6 | 37.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_x101_vd_64x4d_fpn_2x.tar) |
| SENet154-vd-FPN | Faster | 1 | 1.44x | 3.408 | 42.9 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_se154_vd_fpn_s1x.tar) |
| SENet154-vd-FPN | Mask | 1 | 1.44x | 3.233 | 44.0 | 38.7 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_se154_vd_fpn_s1x.tar) |
| ResNet101-vd-FPN | CascadeClsAware Faster | 2 | 1x | - | 44.7(softnms) | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_cls_aware_r101_vd_fpn_1x_softnms.tar) |
### Deformable 卷积网络v2
......@@ -76,6 +79,9 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
| ResNet101-vd-FPN | Cascade Faster | c3-c5 | 2 | 1x | - | 46.4 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_dcn_r101_vd_fpn_1x.tar) |
| ResNeXt101-vd-FPN | Cascade Faster | c3-c5 | 2 | 1x | - | 47.3 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_dcn_x101_vd_64x4d_fpn_1x.tar) |
| SENet154-vd-FPN | Cascade Mask | c3-c5 | 1 | 1.44x | - | 51.9 | 43.9 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cascade_mask_rcnn_dcnv2_se154_vd_fpn_gn_s1x.tar) |
| ResNet200-vd-FPN-Nonlocal | CascadeClsAware Faster | c3-c5 | 1 | 2.5x | - | 51.7%(softnms) | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_cls_aware_r200_vd_fpn_dcnv2_nonlocal_softnms.tar) |
| CBResNet200-vd-FPN-Nonlocal | Cascade Faster | c3-c5 | 1 | 2.5x | - | 53.3%(softnms) | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_cbr200_vd_fpn_dcnv2_nonlocal_softnms.tar) |
#### 注意事项:
- Deformable卷积网络v2(dcn_v2)参考自论文[Deformable ConvNets v2](https://arxiv.org/abs/1811.11168).
......
......@@ -18,6 +18,7 @@ from . import faster_rcnn
from . import mask_rcnn
from . import cascade_rcnn
from . import cascade_mask_rcnn
from . import cascade_rcnn_cls_aware
from . import yolov3
from . import ssd
from . import retinanet
......@@ -28,6 +29,7 @@ from .faster_rcnn import *
from .mask_rcnn import *
from .cascade_rcnn import *
from .cascade_mask_rcnn import *
from .cascade_rcnn_cls_aware import *
from .yolov3 import *
from .ssd import *
from .retinanet import *
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sys
import paddle.fluid as fluid
from ppdet.core.workspace import register
__all__ = ['CascadeRCNNClsAware']
@register
class CascadeRCNNClsAware(object):
"""
Cascade R-CNN architecture, see https://arxiv.org/abs/1712.00726
This is a kind of modification of Cascade R-CNN.
Specifically, it predicts bboxes for all classes with different weights,
while the standard vesion just predicts bboxes for foreground
Args:
backbone (object): backbone instance
rpn_head (object): `RPNhead` instance
bbox_assigner (object): `BBoxAssigner` instance
roi_extractor (object): ROI extractor instance
bbox_head (object): `BBoxHead` instance
fpn (object): feature pyramid network instance
"""
__category__ = 'architecture'
__inject__ = [
'backbone', 'fpn', 'rpn_head', 'bbox_assigner', 'roi_extractor',
'bbox_head'
]
def __init__(self,
backbone,
rpn_head,
roi_extractor='FPNRoIAlign',
bbox_head='CascadeBBoxHead',
bbox_assigner='CascadeBBoxAssigner',
fpn='FPN',
):
super(CascadeRCNNClsAware, self).__init__()
assert fpn is not None, "cascade RCNN requires FPN"
self.backbone = backbone
self.fpn = fpn
self.rpn_head = rpn_head
self.bbox_assigner = bbox_assigner
self.roi_extractor = roi_extractor
self.bbox_head = bbox_head
self.bbox_clip = np.log(1000. / 16.)
# Cascade local cfg
(brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights
self.cascade_bbox_reg_weights = [
[1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0],
[1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1],
[1. / brw2, 1. / brw2, 2. / brw2, 2. / brw2]
]
self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25]
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
im_info = feed_vars['im_info']
if mode == 'train':
gt_box = feed_vars['gt_box']
is_crowd = feed_vars['is_crowd']
gt_label = feed_vars['gt_label']
else:
im_shape = feed_vars['im_shape']
# backbone
body_feats = self.backbone(im)
# FPN
if self.fpn is not None:
body_feats, spatial_scale = self.fpn.get_output(body_feats)
# rpn proposals
rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)
if mode == 'train':
rpn_loss = self.rpn_head.get_loss(im_info, gt_box, is_crowd)
proposal_list = []
roi_feat_list = []
rcnn_pred_list = []
rcnn_target_list = []
bbox_pred = None
self.cascade_var_v = []
for stage in range(3):
var_v = np.array(self.cascade_bbox_reg_weights[stage], dtype="float32")
prior_box_var = fluid.layers.create_tensor(dtype="float32")
fluid.layers.assign(input=var_v, output=prior_box_var)
self.cascade_var_v.append(prior_box_var)
self.cascade_decoded_box = []
self.cascade_cls_prob = []
for stage in range(3):
if stage > 0:
pool_rois = decoded_assign_box
else:
pool_rois = rpn_rois
if mode == "train":
self.cascade_var_v[stage].stop_gradient = True
outs = self.bbox_assigner(
input_rois=pool_rois, feed_vars=feed_vars, curr_stage=stage)
pool_rois = outs[0]
rcnn_target_list.append( outs )
# extract roi features
roi_feat = self.roi_extractor(body_feats, pool_rois, spatial_scale)
roi_feat_list.append(roi_feat)
# bbox head
cls_score, bbox_pred = self.bbox_head.get_output(
roi_feat,
cls_agnostic_bbox_reg=self.bbox_head.num_classes,
wb_scalar=1.0 / self.cascade_rcnn_loss_weight[stage],
name='_' + str(stage + 1) )
cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
decoded_box, decoded_assign_box = fluid.layers.box_decoder_and_assign(
pool_rois,
self.cascade_var_v[stage],
bbox_pred,
cls_prob,
self.bbox_clip)
if mode == "train":
decoded_box.stop_gradient = True
decoded_assign_box.stop_gradient = True
else:
self.cascade_cls_prob.append( cls_prob )
self.cascade_decoded_box.append(decoded_box)
rcnn_pred_list.append((cls_score, bbox_pred))
# out loop
if mode == 'train':
loss = self.bbox_head.get_loss(rcnn_pred_list,
rcnn_target_list,
self.cascade_rcnn_loss_weight)
loss.update(rpn_loss)
total_loss = fluid.layers.sum(list(loss.values()))
loss.update({'loss': total_loss})
return loss
else:
pred = self.bbox_head.get_prediction_cls_aware(
im_info, im_shape,
self.cascade_cls_prob,
self.cascade_decoded_box,
self.cascade_bbox_reg_weights)
return pred
def train(self, feed_vars):
return self.build(feed_vars, 'train')
def eval(self, feed_vars):
return self.build(feed_vars, 'test')
def test(self, feed_vars):
return self.build(feed_vars, 'test')
......@@ -23,6 +23,7 @@ from . import fpn
from . import vgg
from . import blazenet
from . import faceboxnet
from . import cb_resnet
from .resnet import *
from .resnext import *
......@@ -33,3 +34,4 @@ from .fpn import *
from .vgg import *
from .blazenet import *
from .faceboxnet import *
from .cb_resnet import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant
from ppdet.core.workspace import register, serializable
from numbers import Integral
from .name_adapter import NameAdapter
from .nonlocal_helper import add_space_nonlocal
__all__ = ['CBResNet']
@register
@serializable
class CBResNet(object):
"""
CBNet, see https://arxiv.org/abs/1909.03625
Args:
depth (int): ResNet depth, should be 18, 34, 50, 101, 152.
freeze_at (int): freeze the backbone at which stage
norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel'
freeze_norm (bool): freeze normalization layers
norm_decay (float): weight decay for normalization layer weights
variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
nonlocal_stages (list): index of stages who select nonlocal networks
repeat_num (int): number of repeat for backbone
Attention:
1. Here we set the ResNet as the base backbone.
2. All the pretraned params are copied from corresponding names,
but with different names to avoid name refliction.
"""
def __init__(self,
depth=50,
freeze_at=2,
norm_type='bn',
freeze_norm=True,
norm_decay=0.,
variant='b',
feature_maps=[2, 3, 4, 5],
dcn_v2_stages=[],
nonlocal_stages = [],
repeat_num = 2):
super(CBResNet, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
assert depth in [18, 34, 50, 101, 152, 200], \
"depth {} not in [18, 34, 50, 101, 152, 200]"
assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
assert len(feature_maps) > 0, "need one or more feature maps"
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and depth<50), \
"non-local is not supported for resnet18 or resnet34"
self.depth = depth
self.dcn_v2_stages = dcn_v2_stages
self.freeze_at = freeze_at
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.variant = variant
self._model_type = 'ResNet'
self.feature_maps = feature_maps
self.repeat_num = repeat_num
self.curr_level = 0
self.depth_cfg = {
18: ([2, 2, 2, 2], self.basicblock),
34: ([3, 4, 6, 3], self.basicblock),
50: ([3, 4, 6, 3], self.bottleneck),
101: ([3, 4, 23, 3], self.bottleneck),
152: ([3, 8, 36, 3], self.bottleneck),
200: ([3, 12, 48, 3], self.bottleneck),
}
self.nonlocal_stages = nonlocal_stages
self.nonlocal_mod_cfg = {
50 : 2,
101 : 5,
152 : 8,
200 : 12,
}
self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64
self.na = NameAdapter(self)
def _conv_offset(self, input, filter_size, stride, padding, act=None, name=None):
out_channel = filter_size * filter_size * 3
out = fluid.layers.conv2d(input,
num_filters=out_channel,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=ParamAttr(
initializer=Constant(0.0), name=name + ".w_0"),
bias_attr=ParamAttr(
initializer=Constant(0.0), name=name + ".b_0"),
act=act,
name=name)
return out
def _conv_norm(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None,
dcn=False):
if not dcn:
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights_"+str(self.curr_level)),
bias_attr=False)
else:
offset_mask = self._conv_offset(
input=input,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
act=None,
name=name + "_conv_offset_" + str(self.curr_level))
offset_channel = filter_size ** 2 * 2
mask_channel = filter_size ** 2
offset, mask = fluid.layers.split(
input=offset_mask,
num_or_sections=[offset_channel, mask_channel],
dim=1)
mask = fluid.layers.sigmoid(mask)
conv = fluid.layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
deformable_groups=1,
im2col_step=1,
param_attr=ParamAttr(name=name + "_weights_"+str(self.curr_level)),
bias_attr=False)
bn_name = self.na.fix_conv_norm_name(name)
norm_lr = 0. if self.freeze_norm else 1.
norm_decay = self.norm_decay
pattr = ParamAttr(
name=bn_name + '_scale_'+str(self.curr_level),
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
battr = ParamAttr(
name=bn_name + '_offset_'+str(self.curr_level),
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
if self.norm_type in ['bn', 'sync_bn']:
global_stats = True if self.freeze_norm else False
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1_'+str(self.curr_level),
param_attr=pattr,
bias_attr=battr,
moving_mean_name=bn_name + '_mean_'+str(self.curr_level),
moving_variance_name=bn_name + '_variance_'+str(self.curr_level),
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif self.norm_type == 'affine_channel':
assert False, "deprecated!!!"
if self.freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
def _shortcut(self, input, ch_out, stride, is_first, name):
max_pooling_in_short_cut = self.variant == 'd'
ch_in = input.shape[1]
# the naming rule is same as pretrained weight
name = self.na.fix_shortcut_name(name)
if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
if max_pooling_in_short_cut and not is_first:
input = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
ceil_mode=True,
pool_type='avg')
return self._conv_norm(input, ch_out, 1, 1, name=name)
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck(self, input, num_filters, stride, is_first, name, dcn=False):
if self.variant == 'a':
stride1, stride2 = stride, 1
else:
stride1, stride2 = 1, stride
# ResNeXt
groups = getattr(self, 'groups', 1)
group_width = getattr(self, 'group_width', -1)
if groups == 1:
expand = 4
elif (groups * group_width) == 256:
expand = 1
else: # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
num_filters = num_filters // 2
expand = 2
conv_name1, conv_name2, conv_name3, \
shortcut_name = self.na.fix_bottleneck_name(name)
conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1],
[num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
residual = input
for i, (c, k, s, act, g, _name) in enumerate(conv_def):
residual = self._conv_norm(
input=residual,
num_filters=c,
filter_size=k,
stride=s,
act=act,
groups=g,
name=_name,
dcn=(i==1 and dcn))
short = self._shortcut(
input,
num_filters * expand,
stride,
is_first=is_first,
name=shortcut_name)
# Squeeze-and-Excitation
if callable(getattr(self, '_squeeze_excitation', None)):
residual = self._squeeze_excitation(
input=residual, num_channels=num_filters, name='fc' + name)
return fluid.layers.elementwise_add(
x=short, y=residual, act='relu')
def basicblock(self, input, num_filters, stride, is_first, name, dcn=False):
assert dcn is False, "Not implemented yet."
conv0 = self._conv_norm(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self._conv_norm(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self._shortcut(
input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def layer_warp(self, input, stage_num):
"""
Args:
input (Variable): input variable.
stage_num (int): the stage number, should be 2, 3, 4, 5
Returns:
The last variable in endpoint-th stage.
"""
assert stage_num in [2, 3, 4, 5]
stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2]
ch_out = self.stage_filters[stage_num - 2]
is_first = False if stage_num != 2 else True
dcn = True if stage_num in self.dcn_v2_stages else False
nonlocal_mod = 1000
if stage_num in self.nonlocal_stages:
nonlocal_mod = self.nonlocal_mod_cfg[self.depth] if stage_num==4 else 2
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv = input
for i in range(count):
conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
if self.depth < 50:
is_first = True if i == 0 and stage_num == 2 else False
conv = block_func(
input=conv,
num_filters=ch_out,
stride=2 if i == 0 and stage_num != 2 else 1,
is_first=is_first,
name=conv_name,
dcn=dcn)
# add non local model
dim_in = conv.shape[1]
nonlocal_name = "nonlocal_conv{}_lvl{}".format( stage_num, self.curr_level )
if i % nonlocal_mod == nonlocal_mod - 1:
conv = add_space_nonlocal(
conv, dim_in, dim_in,
nonlocal_name + '_{}'.format(i), int(dim_in / 2) )
return conv
def c1_stage(self, input):
out_chan = self._c1_out_chan_num
conv1_name = self.na.fix_c1_stage_name()
if self.variant in ['c', 'd']:
conv1_1_name= "conv1_1"
conv1_2_name= "conv1_2"
conv1_3_name= "conv1_3"
conv_def = [
[out_chan // 2, 3, 2, conv1_1_name],
[out_chan // 2, 3, 1, conv1_2_name],
[out_chan, 3, 1, conv1_3_name],
]
else:
conv_def = [[out_chan, 7, 2, conv1_name]]
for (c, k, s, _name) in conv_def:
input = self._conv_norm(
input=input,
num_filters=c,
filter_size=k,
stride=s,
act='relu',
name=_name)
output = fluid.layers.pool2d(
input=input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
return output
def connect( self, left, right, name ):
ch_right = right.shape[1]
conv = self._conv_norm( left,
num_filters=ch_right,
filter_size=1,
stride=1,
act="relu",
name=name+"_connect")
shape = fluid.layers.shape(right)
shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4])
out_shape_ = shape_hw
out_shape = fluid.layers.cast(out_shape_, dtype='int32')
out_shape.stop_gradient = True
conv = fluid.layers.resize_nearest(
conv, scale=2., actual_shape=out_shape)
output = fluid.layers.elementwise_add(x=right, y=conv)
return output
def __call__(self, input):
assert isinstance(input, Variable)
assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
"feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
res_endpoints = []
self.curr_level = 0
res = self.c1_stage(input)
feature_maps = range(2, max(self.feature_maps) + 1)
for i in feature_maps:
res = self.layer_warp(res, i)
if i in self.feature_maps:
res_endpoints.append(res)
for num in range(1, self.repeat_num):
self.curr_level = num
res = self.c1_stage(input)
for i in range( len(res_endpoints) ):
res = self.connect( res_endpoints[i], res, "test_c"+str(i+1) )
res = self.layer_warp(res, i+2)
res_endpoints[i] = res
if self.freeze_at >= i+2:
res.stop_gradient = True
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)])
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
nonlocal_params = {
"use_zero_init_conv" : False,
"conv_init_std" : 0.01,
"no_bias" : True,
"use_maxpool" : False,
"use_softmax" : True,
"use_bn" : False,
"use_scale" : True, # vital for the model prformance!!!
"use_affine" : False,
"bn_momentum" : 0.9,
"bn_epsilon" : 1.0000001e-5,
"bn_init_gamma" : 0.9,
"weight_decay_bn":1.e-4,
}
def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride = 2):
cur = input
theta = fluid.layers.conv2d(input = cur, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr=ParamAttr(name = prefix + '_theta' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_theta' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if not nonlocal_params["no_bias"] else False, \
name = prefix + '_theta')
theta_shape = theta.shape
theta_shape_op = fluid.layers.shape( theta )
theta_shape_op.stop_gradient = True
if nonlocal_params["use_maxpool"]:
max_pool = fluid.layers.pool2d(input = cur, \
pool_size = [max_pool_stride, max_pool_stride], \
pool_type = 'max', \
pool_stride = [max_pool_stride, max_pool_stride], \
pool_padding = [0, 0], \
name = prefix + '_pool')
else:
max_pool = cur
phi = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_phi' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_phi' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_phi')
phi_shape = phi.shape
g = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_g' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0, scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_g' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_g')
g_shape = g.shape
# we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta = fluid.layers.reshape(theta, shape=(0, 0, -1) )
theta = fluid.layers.transpose(theta, [0, 2, 1])
phi = fluid.layers.reshape(phi, [0, 0, -1])
theta_phi = fluid.layers.matmul(theta, phi, name = prefix + '_affinity')
g = fluid.layers.reshape(g, [0, 0, -1])
if nonlocal_params["use_softmax"]:
if nonlocal_params["use_scale"]:
theta_phi_sc = fluid.layers.scale(theta_phi, scale = dim_inner**-.5)
else:
theta_phi_sc = theta_phi
p = fluid.layers.softmax(theta_phi_sc, name = prefix + '_affinity' + '_prob')
else:
# not clear about what is doing in xlw's code
p = None # not implemented
raise "Not implemented when not use softmax"
# note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p = fluid.layers.transpose(p, [0, 2, 1])
t = fluid.layers.matmul(g, p, name = prefix + '_y')
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
t_shape = t.shape
t_re = fluid.layers.reshape(t, shape=list(theta_shape), actual_shape=theta_shape_op )
blob_out = t_re
blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \
filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_out' + "_w", \
initializer = fluid.initializer.Constant(value = 0.) \
if nonlocal_params["use_zero_init_conv"] \
else fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_out' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_out')
blob_out_shape = blob_out.shape
if nonlocal_params["use_bn"]:
bn_name = prefix + "_bn"
blob_out = fluid.layers.batch_norm(blob_out, \
# is_test = test_mode, \
momentum = nonlocal_params["bn_momentum"], \
epsilon = nonlocal_params["bn_epsilon"], \
name = bn_name, \
param_attr = ParamAttr(name = bn_name + "_s", \
initializer = fluid.initializer.Constant(value = nonlocal_params["bn_init_gamma"]), \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
bias_attr = ParamAttr(name = bn_name + "_b", \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
moving_mean_name = bn_name + "_rm", \
moving_variance_name = bn_name + "_riv") # add bn
if nonlocal_params["use_affine"]:
affine_scale = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_s'), \
default_initializer = fluid.initializer.Constant(value = 1.))
affine_bias = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_b'), \
default_initializer = fluid.initializer.Constant(value = 0.))
blob_out = fluid.layers.affine_channel(blob_out, scale = affine_scale, \
bias = affine_bias, name = prefix + '_affine') # add affine
return blob_out
def add_space_nonlocal(input, dim_in, dim_out, prefix, dim_inner ):
'''
add_space_nonlocal:
Non-local Neural Networks: see https://arxiv.org/abs/1711.07971
'''
conv = space_nonlocal(input, dim_in, dim_out, prefix, dim_inner)
output = fluid.layers.elementwise_add(input, conv, name = prefix + '_sum')
return output
......@@ -27,6 +27,7 @@ from paddle.fluid.initializer import Constant
from ppdet.core.workspace import register, serializable
from numbers import Integral
from .nonlocal import add_space_nonlocal
from .name_adapter import NameAdapter
__all__ = ['ResNet', 'ResNetC5']
......@@ -46,6 +47,7 @@ class ResNet(object):
variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
nonlocal_stages (list): index of stages who select nonlocal networks
"""
__shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']
......@@ -58,18 +60,21 @@ class ResNet(object):
variant='b',
feature_maps=[2, 3, 4, 5],
dcn_v2_stages=[],
weight_prefix_name=''):
weight_prefix_name='',
nonlocal_stages=[]):
super(ResNet, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
assert depth in [18, 34, 50, 101, 152], \
"depth {} not in [18, 34, 50, 101, 152]"
assert depth in [18, 34, 50, 101, 152, 200], \
"depth {} not in [18, 34, 50, 101, 152, 200]"
assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
assert len(feature_maps) > 0, "need one or more feature maps"
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and depth<50), \
"non-local is not supported for resnet18 or resnet34"
self.depth = depth
self.freeze_at = freeze_at
......@@ -85,12 +90,21 @@ class ResNet(object):
34: ([3, 4, 6, 3], self.basicblock),
50: ([3, 4, 6, 3], self.bottleneck),
101: ([3, 4, 23, 3], self.bottleneck),
152: ([3, 8, 36, 3], self.bottleneck)
152: ([3, 8, 36, 3], self.bottleneck),
200: ([3, 12, 48, 3], self.bottleneck),
}
self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64
self.na = NameAdapter(self)
self.prefix_name = weight_prefix_name
self.nonlocal_stages = nonlocal_stages
self.nonlocal_mod_cfg = {
50 : 2,
101 : 5,
152 : 8,
200 : 12,
}
def _conv_offset(self,
input,
......@@ -340,6 +354,11 @@ class ResNet(object):
ch_out = self.stage_filters[stage_num - 2]
is_first = False if stage_num != 2 else True
dcn_v2 = True if stage_num in self.dcn_v2_stages else False
nonlocal_mod = 1000
if stage_num in self.nonlocal_stages:
nonlocal_mod = self.nonlocal_mod_cfg[self.depth] if stage_num==4 else 2
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv = input
......@@ -354,6 +373,14 @@ class ResNet(object):
is_first=is_first,
name=conv_name,
dcn_v2=dcn_v2)
# add non local model
dim_in = conv.shape[1]
nonlocal_name = "nonlocal_conv{}".format( stage_num )
if i % nonlocal_mod == nonlocal_mod - 1:
conv = add_space_nonlocal(
conv, dim_in, dim_in,
nonlocal_name + '_{}'.format(i), int(dim_in / 2) )
return conv
def c1_stage(self, input):
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from numbers import Integral
from paddle import fluid
......@@ -22,7 +23,8 @@ from ppdet.core.workspace import register, serializable
__all__ = [
'AnchorGenerator', 'RPNTargetAssign', 'GenerateProposals', 'MultiClassNMS',
'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool', 'MultiBoxHead',
'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm'
'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
'MultiClassSoftNMS'
]
......@@ -205,6 +207,113 @@ class MultiClassNMS(object):
self.nms_eta = nms_eta
self.background_label = background_label
@register
@serializable
class MultiClassSoftNMS(object):
def __init__(self,
score_threshold=0.01,
keep_top_k=300,
softnms_sigma=0.5,
normalized=False,
background_label=0,
):
super(MultiClassSoftNMS, self).__init__()
self.score_threshold = score_threshold
self.keep_top_k = keep_top_k
self.softnms_sigma = softnms_sigma
self.normalized = normalized
self.background_label = background_label
def __call__( self, bboxes, scores ):
def create_tmp_var(program, name, dtype, shape, lod_leval):
return program.current_block().create_var(name=name,
dtype=dtype,
shape=shape,
lod_leval=lod_leval)
def _soft_nms_for_cls(dets, sigma, thres):
"""soft_nms_for_cls"""
dets_final = []
while len(dets) > 0:
maxpos = np.argmax(dets[:, 0])
dets_final.append(dets[maxpos].copy())
ts, tx1, ty1, tx2, ty2 = dets[maxpos]
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
eta = 0 if self.normalized else 1
areas = (x2 - x1 + eta) * (y2 - y1 + eta)
xx1 = np.maximum(tx1, x1)
yy1 = np.maximum(ty1, y1)
xx2 = np.minimum(tx2, x2)
yy2 = np.minimum(ty2, y2)
w = np.maximum(0.0, xx2 - xx1 + eta)
h = np.maximum(0.0, yy2 - yy1 + eta)
inter = w * h
ovr = inter / (areas + areas[maxpos] - inter)
weight = np.exp(-(ovr * ovr) / sigma)
scores = scores * weight
idx_keep = np.where(scores >= thres)
dets[:, 0] = scores
dets = dets[idx_keep]
dets_final = np.array(dets_final).reshape(-1, 5)
return dets_final
def _soft_nms(bboxes, scores):
bboxes = np.array(bboxes)
scores = np.array(scores)
class_nums = scores.shape[-1]
softnms_thres = self.score_threshold
softnms_sigma = self.softnms_sigma
keep_top_k = self.keep_top_k
cls_boxes = [[] for _ in range(class_nums)]
cls_ids = [[] for _ in range(class_nums)]
start_idx = 1 if self.background_label == 0 else 0
for j in range(start_idx, class_nums):
inds = np.where(scores[:, j] >= softnms_thres)[0]
scores_j = scores[inds, j]
rois_j = bboxes[inds, j, :]
dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(np.float32, copy=False)
cls_rank = np.argsort(-dets_j[:, 0])
dets_j = dets_j[cls_rank]
cls_boxes[j] = _soft_nms_for_cls( dets_j, sigma=softnms_sigma, thres=softnms_thres )
cls_ids[j] = np.array( [j]*cls_boxes[j].shape[0] ).reshape(-1,1)
cls_boxes = np.vstack(cls_boxes[start_idx:])
cls_ids = np.vstack(cls_ids[start_idx:])
pred_result = np.hstack( [cls_ids, cls_boxes] )
# Limit to max_per_image detections **over all classes**
image_scores = cls_boxes[:,0]
if len(image_scores) > keep_top_k:
image_thresh = np.sort(image_scores)[-keep_top_k]
keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
pred_result = pred_result[keep, :]
res = fluid.LoDTensor()
res.set_lod([[0, pred_result.shape[0]]])
if pred_result.shape[0] == 0:
pred_result = np.array( [[1]], dtype=np.float32 )
res.set(pred_result, fluid.CPUPlace())
return res
pred_result = create_tmp_var(fluid.default_main_program(),
name='softnms_pred_result',
dtype='float32',
shape=[6],
lod_leval=1)
fluid.layers.py_func(func=_soft_nms,
x=[bboxes, scores], out=pred_result)
return pred_result
@register
class BBoxAssigner(object):
......
......@@ -219,6 +219,32 @@ class CascadeBBoxHead(object):
return {'bbox': box_out, 'score': boxes_cls_prob_mean}
pred_result = self.nms(bboxes=box_out, scores=boxes_cls_prob_mean)
return {"bbox": pred_result}
def get_prediction_cls_aware(self,
im_info,
im_shape,
cascade_cls_prob,
cascade_decoded_box,
cascade_bbox_reg_weights):
'''
get_prediction_cls_aware: predict bbox for each class
'''
cascade_num_stage = 3
cascade_eval_weight = [0.2, 0.3, 0.5]
# merge 3 stages results
sum_cascade_cls_prob = sum([ prob*cascade_eval_weight[idx] for idx, prob in enumerate(cascade_cls_prob) ])
sum_cascade_decoded_box = sum([ bbox*cascade_eval_weight[idx] for idx, bbox in enumerate(cascade_decoded_box) ])
self.im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
im_scale_lod = fluid.layers.sequence_expand(self.im_scale, sum_cascade_decoded_box)
sum_cascade_decoded_box = sum_cascade_decoded_box / im_scale_lod
decoded_bbox = sum_cascade_decoded_box
decoded_bbox = fluid.layers.reshape(decoded_bbox, shape=(-1, self.num_classes, 4) )
box_out = fluid.layers.box_clip(input=decoded_bbox, im_info=im_shape)
pred_result = self.nms(bboxes=box_out, scores=sum_cascade_cls_prob)
return {"bbox": pred_result}
@register
......
......@@ -35,8 +35,9 @@ class CascadeBBoxAssigner(object):
bg_thresh_hi=[0.5, 0.6, 0.7],
bg_thresh_lo=[0., 0., 0.],
bbox_reg_weights=[10, 20, 30],
shuffle_before_sample=True,
num_classes=81,
shuffle_before_sample=True):
class_aware=False):
super(CascadeBBoxAssigner, self).__init__()
self.batch_size_per_im = batch_size_per_im
self.fg_fraction = fg_fraction
......@@ -46,6 +47,7 @@ class CascadeBBoxAssigner(object):
self.bbox_reg_weights = bbox_reg_weights
self.class_nums = num_classes
self.use_random = shuffle_before_sample
self.class_aware = class_aware
def __call__(self, input_rois, feed_vars, curr_stage):
......@@ -67,7 +69,7 @@ class CascadeBBoxAssigner(object):
bg_thresh_lo=self.bg_thresh_lo[curr_stage],
bbox_reg_weights=curr_bbox_reg_w,
use_random=self.use_random,
class_nums=2,
class_nums=self.class_nums if self.class_aware else 2,
is_cls_agnostic=True,
is_cascade_rcnn=True if curr_stage > 0 else False)
is_cascade_rcnn=True if curr_stage > 0 and not self.class_aware else False)
return outs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册