提交 9597360e 编写于 作者: G Guanghua Yu 提交者: qingqing01

update modelzoo and fix configs (#2789)

* fix resnet for resnet34

* fix ResNet18 and ResNet34

* update modelzoo and fix configs

* fix configs

* fix resnet bn
上级 3371a8c5
...@@ -4,7 +4,7 @@ eval_feed: FasterRCNNEvalFeed ...@@ -4,7 +4,7 @@ eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed test_feed: FasterRCNNTestFeed
max_iters: 180000 max_iters: 180000
snapshot_iter: 10000 snapshot_iter: 10000
use_gpu: True use_gpu: true
log_smooth_window: 20 log_smooth_window: 20
save_dir: output save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar
...@@ -118,7 +118,7 @@ FasterRCNNTrainFeed: ...@@ -118,7 +118,7 @@ FasterRCNNTrainFeed:
- !PadBatch - !PadBatch
pad_to_stride: 32 pad_to_stride: 32
num_workers: 2 num_workers: 2
shuffle: True shuffle: true
FasterRCNNEvalFeed: FasterRCNNEvalFeed:
batch_size: 1 batch_size: 1
...@@ -139,4 +139,4 @@ FasterRCNNTestFeed: ...@@ -139,4 +139,4 @@ FasterRCNNTestFeed:
- !PadBatch - !PadBatch
pad_to_stride: 32 pad_to_stride: 32
num_workers: 2 num_workers: 2
shuffle: False shuffle: false
architecture: FasterRCNN
train_feed: FasterRCNNTrainFeed
eval_feed: FasterRCNNEvalFeed
test_feed: FasterRCNNTestFeed
max_iters: 360000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNeXt101_vd_64x4d_pretrained.tar
weights: output/faster_rcnn_x101_vd_64x4d_fpn_1x/model_final
metric: COCO
FasterRCNN:
backbone: ResNeXt
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: BBoxAssigner
ResNeXt:
depth: 101
feature_maps: [2, 3, 4, 5]
freeze_at: 2
group_width: 4
groups: 64
norm_type: affine_channel
variant: d
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
BBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
num_classes: 81
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
num_classes: 81
TwoFCHead:
num_chan: 1024
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [240000, 320000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
FasterRCNNTrainFeed:
# batch size per device
batch_size: 1
dataset:
dataset_dir: dataset/coco
image_dir: train2017
annotation: annotations/instances_train2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
shuffle: true
FasterRCNNEvalFeed:
batch_size: 1
dataset:
dataset_dir: dataset/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
FasterRCNNTestFeed:
batch_size: 1
dataset:
annotation: annotations/instances_val2017.json
batch_transforms:
- !PadBatch
pad_to_stride: 32
num_workers: 2
shuffle: false
...@@ -49,8 +49,9 @@ The backbone models pretrained on ImageNet are available. All backbone models ar ...@@ -49,8 +49,9 @@ The backbone models pretrained on ImageNet are available. All backbone models ar
| ResNet101-FPN | Faster | 1 | 2x | 39.1 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar) | | ResNet101-FPN | Faster | 1 | 2x | 39.1 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar) |
| ResNet101-FPN | Mask | 1 | 1x | 39.5 | 35.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_fpn_1x.tar) | | ResNet101-FPN | Mask | 1 | 1x | 39.5 | 35.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_fpn_1x.tar) |
| ResNet101-vd-FPN | Faster | 1 | 1x | 40.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_1x.tar) | | ResNet101-vd-FPN | Faster | 1 | 1x | 40.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_1x.tar) |
| ResNet101-vd-FPN | Faster | 1 | 2x | 40.6 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_fpn_2x.tar) | | ResNet101-vd-FPN | Faster | 1 | 2x | 40.8 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar) |
| ResNeXt101-vd-FPN | Faster | 1 | 1x | 42.2 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_x101_vd_64x4d_fpn_1x.tar) | | ResNeXt101-vd-FPN | Faster | 1 | 1x | 42.2 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_x101_vd_64x4d_fpn_1x.tar) |
| ResNeXt101-vd-FPN | Faster | 1 | 2x | 41.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_x101_vd_64x4d_fpn_2x.tar) |
| SENet154-vd-FPN | Faster | 1 | 1.44x | 42.9 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_se154_vd_fpn_s1x.tar) | | SENet154-vd-FPN | Faster | 1 | 1.44x | 42.9 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_se154_vd_fpn_s1x.tar) |
| SENet154-vd-FPN | Mask | 1 | 1.44x | 44.0 | 38.7 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_se154_vd_fpn_s1x.tar) | | SENet154-vd-FPN | Mask | 1 | 1.44x | 44.0 | 38.7 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_se154_vd_fpn_s1x.tar) |
......
...@@ -119,6 +119,7 @@ class ResNet(object): ...@@ -119,6 +119,7 @@ class ResNet(object):
regularizer=L2Decay(norm_decay)) regularizer=L2Decay(norm_decay))
if self.norm_type in ['bn', 'sync_bn']: if self.norm_type in ['bn', 'sync_bn']:
global_stats = True if self.freeze_norm else False
out = fluid.layers.batch_norm( out = fluid.layers.batch_norm(
input=conv, input=conv,
act=act, act=act,
...@@ -126,7 +127,8 @@ class ResNet(object): ...@@ -126,7 +127,8 @@ class ResNet(object):
param_attr=pattr, param_attr=pattr,
bias_attr=battr, bias_attr=battr,
moving_mean_name=bn_name + '_mean', moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', ) moving_variance_name=bn_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name) scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name) bias = fluid.framework._get_var(battr.name)
elif self.norm_type == 'affine_channel': elif self.norm_type == 'affine_channel':
......
...@@ -141,12 +141,12 @@ def main(): ...@@ -141,12 +141,12 @@ def main():
exe.run(startup_prog) exe.run(startup_prog)
freeze_bn = getattr(model.backbone, 'freeze_norm', False) fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
start_iter = 0 start_iter = 0
if FLAGS.resume_checkpoint: if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint) checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step() start_iter = checkpoint.global_step()
elif cfg.pretrain_weights and freeze_bn: elif cfg.pretrain_weights and fuse_bn:
checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights) checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
elif cfg.pretrain_weights: elif cfg.pretrain_weights:
checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights) checkpoint.load_pretrain(exe, train_prog, cfg.pretrain_weights)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册