提交 59393a8d 编写于 作者: Y Yeqing Li 提交者: A. Unique TensorFlower

Imports the mask-rcnn config.

PiperOrigin-RevId: 283987800
上级 c115444f
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base config template."""
# pylint: disable=line-too-long
# For ResNet, this freezes the variables of the first conv1 and conv2_x
# layers [1], which leads to higher training speed and slightly better testing
# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
# is able to capture low-level features such as edges; therefore, it does not
# need to be fine-tuned for the detection task.
# Note that we need to trailing `/` to avoid the incorrect match.
# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
BASE_CFG = {
'model_dir': '',
'use_tpu': True,
'isolate_session_state': False,
'train': {
'iterations_per_loop': 100,
'train_batch_size': 64,
'total_steps': 22500,
'num_cores_per_replica': None,
'input_partition_dims': None,
'optimizer': {
'type': 'momentum',
'momentum': 0.9,
},
'learning_rate': {
'type': 'step',
'warmup_learning_rate': 0.0067,
'warmup_steps': 500,
'init_learning_rate': 0.08,
'learning_rate_levels': [0.008, 0.0008],
'learning_rate_steps': [15000, 20000],
'total_steps': 22500,
},
'checkpoint': {
'path': '',
'prefix': '',
},
'frozen_variable_prefix': RESNET_FROZEN_VAR_PREFIX,
'train_file_pattern': '',
'train_dataset_type': 'tfrecord',
'transpose_input': True,
'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
'l2_weight_decay': 0.0001,
'gradient_clip_norm': 0.0,
},
'eval': {
'eval_batch_size': 8,
'eval_samples': 5000,
'min_eval_interval': 180,
'eval_timeout': None,
'num_steps_per_eval': 1000,
'type': 'box',
'use_json_file': True,
'val_json_file': '',
'eval_file_pattern': '',
'eval_dataset_type': 'tfrecord',
},
'predict': {
'predict_batch_size': 8,
},
'anchor': {
'min_level': 3,
'max_level': 7,
'num_scales': 3,
'aspect_ratios': [1.0, 2.0, 0.5],
'anchor_size': 4.0,
},
'resnet': {
'resnet_depth': 50,
'dropblock': {
'dropblock_keep_prob': None,
'dropblock_size': None,
},
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'fpn': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256,
'use_separable_conv': False,
'use_batch_norm': True,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'nasfpn': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256,
'num_repeats': 5,
'use_separable_conv': False,
'dropblock': {
'dropblock_keep_prob': None,
'dropblock_size': None,
},
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
# tunable_nasfpn:strip_begin
'tunable_nasfpn_v1': {
'min_level': 3,
'max_level': 7,
'fpn_feat_dims': 256,
'num_repeats': 5,
'use_separable_conv': False,
'dropblock': {
'dropblock_keep_prob': None,
'dropblock_size': None,
},
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
'nodes': None
},
# tunable_nasfpn:strip_end
'postprocess': {
'use_batched_nms': False,
'max_total_size': 100,
'nms_iou_threshold': 0.5,
'score_threshold': 0.05,
'pre_nms_num_boxes': 5000,
},
'enable_summary': False,
}
# pylint: enable=line-too-long
......@@ -14,6 +14,7 @@
# ==============================================================================
"""Factory to provide model configs."""
from official.vision.detection.configs import maskrcnn_config
from official.vision.detection.configs import retinanet_config
from official.modeling.hyperparams import params_dict
......@@ -23,6 +24,9 @@ def config_generator(model):
if model == 'retinanet':
default_config = retinanet_config.RETINANET_CFG
restrictions = retinanet_config.RETINANET_RESTRICTIONS
elif model == 'mask_rcnn':
default_config = maskrcnn_config.MASKRCNN_CFG
restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS
else:
raise ValueError('Model %s is not supported.' % model)
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config template to train Mask R-CNN."""
from official.vision.detection.configs import base_config
from official.modeling.hyperparams import params_dict
# pylint: disable=line-too-long
MASKRCNN_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
MASKRCNN_CFG.override({
'type': 'mask_rcnn',
'eval': {
'type': 'box_and_mask',
},
'architecture': {
'parser': 'maskrcnn_parser',
'backbone': 'resnet',
'multilevel_features': 'fpn',
'use_bfloat16': True,
'include_mask': True,
},
'maskrcnn_parser': {
'use_bfloat16': True,
'output_size': [1024, 1024],
'rpn_match_threshold': 0.7,
'rpn_unmatched_threshold': 0.3,
'rpn_batch_size_per_im': 256,
'rpn_fg_fraction': 0.5,
'aug_rand_hflip': True,
'aug_scale_min': 1.0,
'aug_scale_max': 1.0,
'skip_crowd_during_training': True,
'max_num_instances': 100,
'include_mask': True,
'mask_crop_size': 112,
},
'anchor': {
'min_level': 2,
'max_level': 6,
'num_scales': 1,
'anchor_size': 8,
},
'fpn': {
'min_level': 2,
'max_level': 6,
},
'nasfpn': {
'min_level': 2,
'max_level': 6,
},
# tunable_nasfpn:strip_begin
'tunable_nasfpn_v1': {
'min_level': 2,
'max_level': 6,
},
# tunable_nasfpn:strip_end
'rpn_head': {
'min_level': 2,
'max_level': 6,
'anchors_per_location': 3,
'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'frcnn_head': {
# Note that `num_classes` is the total number of classes including
# one background classes whose index is 0.
'num_classes': 91,
'fast_rcnn_mlp_head_dim': 1024,
'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'mrcnn_head': {
'num_classes': 91,
'mask_target_size': 28,
'use_batch_norm': False,
'batch_norm': {
'batch_norm_momentum': 0.997,
'batch_norm_epsilon': 1e-4,
'batch_norm_trainable': True,
'use_sync_bn': False,
},
},
'rpn_score_loss': {
'rpn_batch_size_per_im': 256,
},
'rpn_box_loss': {
'huber_loss_delta': 1.0 / 9.0,
},
'frcnn_box_loss': {
'huber_loss_delta': 1.0,
},
'roi_proposal': {
'rpn_pre_nms_top_k': 2000,
'rpn_post_nms_top_k': 1000,
'rpn_nms_threshold': 0.7,
'rpn_score_threshold': 0.0,
'rpn_min_size_threshold': 0.0,
'test_rpn_pre_nms_top_k': 1000,
'test_rpn_post_nms_top_k': 1000,
'test_rpn_nms_threshold': 0.7,
'test_rpn_score_threshold': 0.0,
'test_rpn_min_size_threshold': 0.0,
'use_batched_nms': False,
},
'roi_sampling': {
'num_samples_per_image': 512,
'fg_fraction': 0.25,
'fg_iou_thresh': 0.5,
'bg_iou_thresh_hi': 0.5,
'bg_iou_thresh_lo': 0.0,
'mix_gt_boxes': True,
},
'mask_sampling': {
'num_mask_samples_per_image': 128, # Typically = `num_samples_per_image` * `fg_fraction`.
'mask_target_size': 28,
},
'postprocess': {
'use_batched_nms': False,
'max_total_size': 100,
'nms_iou_threshold': 0.5,
'score_threshold': 0.05,
'pre_nms_num_boxes': 1000,
},
}, is_strict=False)
MASKRCNN_RESTRICTIONS = [
'architecture.use_bfloat16 == maskrcnn_parser.use_bfloat16',
'architecture.include_mask == maskrcnn_parser.include_mask',
'anchor.min_level == rpn_head.min_level',
'anchor.max_level == rpn_head.max_level',
'mrcnn_head.mask_target_size == mask_sampling.mask_target_size',
]
# pylint: enable=line-too-long
......@@ -39,7 +39,7 @@ RETINANET_CFG = {
'optimizer': {
'type': 'momentum',
'momentum': 0.9,
'nesterov': True,
'nesterov': True, # `False` is better for TPU v3-128.
},
'learning_rate': {
'type': 'step',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册