提交 ed0d9c71 编写于 作者: F Fan Yang 提交者: A. Unique TensorFlower

MobileNetV2 backbone for maskrcnn.

PiperOrigin-RevId: 453472688
上级 cb386cfb
# Expect to reach: box mAP: 33.3%, mask mAP: 29.4% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: gs://**/mobilenetv2_gpu/22984194/ckpt-625500
init_checkpoint_modules: 'backbone'
train_data:
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
losses:
l2_weight_decay: 0.00004
model:
anchor:
anchor_size: 3.0
num_scales: 3
detection_generator:
pre_nms_top_k: 1000
......@@ -524,3 +524,91 @@ def cascadercnn_spinenet_coco() -> cfg.ExperimentConfig:
'task.model.max_level == task.model.backbone.spinenet.max_level',
])
return config
@exp_factory.register_config_factory('maskrcnn_mobilenet_coco')
def maskrcnn_mobilenet_coco() -> cfg.ExperimentConfig:
"""COCO object detection with Mask R-CNN with MobileNet backbone."""
steps_per_epoch = 232
coco_val_samples = 5000
train_batch_size = 512
eval_batch_size = 512
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=MaskRCNNTask(
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=MaskRCNN(
backbone=backbones.Backbone(
type='mobilenet',
mobilenet=backbones.MobileNet(model_id='MobileNetV2')),
decoder=decoders.Decoder(
type='fpn',
fpn=decoders.FPN(num_filters=128, use_separable_conv=True)),
rpn_head=RPNHead(use_separable_conv=True,
num_filters=128), # 1/2 of original channels.
detection_head=DetectionHead(
use_separable_conv=True, num_filters=128,
fc_dims=512), # 1/2 of original channels.
mask_head=MaskHead(use_separable_conv=True,
num_filters=128), # 1/2 of original channels.
anchor=Anchor(anchor_size=3),
norm_activation=common.NormActivation(
activation='relu6',
norm_momentum=0.99,
norm_epsilon=0.001,
use_sync_bn=True),
num_classes=91,
input_size=[512, 512, 3],
min_level=3,
max_level=6,
include_mask=True),
losses=Losses(l2_weight_decay=0.00004),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
parser=Parser(
aug_rand_hflip=True, aug_scale_min=0.5, aug_scale_max=2.0)),
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=steps_per_epoch * 350,
validation_steps=coco_val_samples // eval_batch_size,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [
steps_per_epoch * 320, steps_per_epoch * 340
],
'values': [0.32, 0.032, 0.0032],
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 2000,
'warmup_learning_rate': 0.0067
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])
return config
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册