提交 8f58f396 编写于 作者: X Xianzhi Du 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 371847148
上级 71c98d5d
# Expected to reach: box mAP: 51.6%, mask mAP: 44.5% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: null
train_data:
global_batch_size: 256
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.5
losses:
l2_weight_decay: 0.00004
model:
anchor:
anchor_size: 4.0
num_scales: 3
min_level: 3
max_level: 7
input_size: [1280, 1280, 3]
backbone:
spinenet:
stochastic_depth_drop_rate: 0.2
model_id: '143'
type: 'spinenet'
decoder:
type: 'identity'
detection_head:
cascade_class_ensemble: true
class_agnostic_bbox_pred: true
rpn_head:
num_convs: 2
num_filters: 256
roi_sampler:
cascade_iou_thresholds: [0.7]
foreground_iou_threshold: 0.6
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
activation: 'swish'
detection_generator:
pre_nms_top_k: 1000
trainer:
train_steps: 162050
optimizer_config:
learning_rate:
type: 'stepwise'
stepwise:
boundaries: [148160, 157420]
values: [0.32, 0.032, 0.0032]
warmup:
type: 'linear'
linear:
warmup_steps: 2000
warmup_learning_rate: 0.0067
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: null
train_data:
global_batch_size: 256
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
losses:
l2_weight_decay: 0.00004
model:
anchor:
anchor_size: 3.0
num_scales: 3
min_level: 3
max_level: 7
input_size: [640, 640, 3]
backbone:
spinenet:
stochastic_depth_drop_rate: 0.2
model_id: '49'
type: 'spinenet'
decoder:
type: 'identity'
detection_head:
cascade_class_ensemble: true
class_agnostic_bbox_pred: true
rpn_head:
num_convs: 2
num_filters: 256
roi_sampler:
cascade_iou_thresholds: [0.7]
foreground_iou_threshold: 0.6
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
activation: 'swish'
detection_generator:
pre_nms_top_k: 1000
trainer:
train_steps: 231000
optimizer_config:
learning_rate:
type: 'stepwise'
stepwise:
boundaries: [219450, 226380]
values: [0.28, 0.028, 0.0028]
warmup:
type: 'linear'
linear:
warmup_steps: 2000
warmup_learning_rate: 0.0067
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: null
train_data:
global_batch_size: 256
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
losses:
l2_weight_decay: 0.00004
model:
anchor:
anchor_size: 3.0
num_scales: 3
min_level: 3
max_level: 7
input_size: [1024, 1024, 3]
backbone:
spinenet:
stochastic_depth_drop_rate: 0.2
model_id: '96'
type: 'spinenet'
decoder:
type: 'identity'
detection_head:
cascade_class_ensemble: true
class_agnostic_bbox_pred: true
rpn_head:
num_convs: 2
num_filters: 256
roi_sampler:
cascade_iou_thresholds: [0.7]
foreground_iou_threshold: 0.6
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
activation: 'swish'
detection_generator:
pre_nms_top_k: 1000
trainer:
train_steps: 231000
optimizer_config:
learning_rate:
type: 'stepwise'
stepwise:
boundaries: [219450, 226380]
values: [0.32, 0.032, 0.0032]
warmup:
type: 'linear'
linear:
warmup_steps: 2000
warmup_learning_rate: 0.0067
......@@ -114,7 +114,28 @@ def build_maskrcnn(
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
kernel_regularizer=l2_regularizer,
name='detection_head')
if roi_sampler_config.cascade_iou_thresholds:
detection_head_cascade = [detection_head]
for cascade_num in range(len(roi_sampler_config.cascade_iou_thresholds)):
detection_head = instance_heads.DetectionHead(
num_classes=model_config.num_classes,
num_convs=detection_head_config.num_convs,
num_filters=detection_head_config.num_filters,
use_separable_conv=detection_head_config.use_separable_conv,
num_fcs=detection_head_config.num_fcs,
fc_dims=detection_head_config.fc_dims,
class_agnostic_bbox_pred=detection_head_config
.class_agnostic_bbox_pred,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer,
name='detection_head_{}'.format(cascade_num + 1))
detection_head_cascade.append(detection_head)
detection_head = detection_head_cascade
roi_generator_obj = roi_generator.MultilevelROIGenerator(
pre_nms_top_k=roi_generator_config.pre_nms_top_k,
......
......@@ -31,7 +31,8 @@ class MaskRCNNModel(tf.keras.Model):
backbone: tf.keras.Model,
decoder: tf.keras.Model,
rpn_head: tf.keras.layers.Layer,
detection_head: tf.keras.layers.Layer,
detection_head: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_generator: tf.keras.layers.Layer,
roi_sampler: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
......@@ -54,7 +55,7 @@ class MaskRCNNModel(tf.keras.Model):
backbone: `tf.keras.Model`, the backbone network.
decoder: `tf.keras.Model`, the decoder network.
rpn_head: the RPN head.
detection_head: the detection head.
detection_head: a single detection head or a list of detection heads for
cascade detection.
roi_generator: the ROI generator.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
......@@ -104,7 +105,10 @@ class MaskRCNNModel(tf.keras.Model):
self.backbone = backbone
self.decoder = decoder
self.rpn_head = rpn_head
self.detection_head = detection_head
if not isinstance(detection_head, (list, tuple)):
self.detection_head = [detection_head]
else:
self.detection_head = detection_head
self.roi_generator = roi_generator
if not isinstance(roi_sampler, (list, tuple)):
self.roi_sampler = [roi_sampler]
......@@ -191,7 +195,7 @@ class MaskRCNNModel(tf.keras.Model):
gt_classes=gt_classes,
training=training,
model_outputs=model_outputs,
layer_num=cascade_num,
cascade_num=cascade_num,
regression_weights=regression_weights)
all_class_outputs.append(class_outputs)
......@@ -266,7 +270,7 @@ class MaskRCNNModel(tf.keras.Model):
return model_outputs
def _run_frcnn_head(self, features, rois, gt_boxes, gt_classes, training,
model_outputs, layer_num, regression_weights):
model_outputs, cascade_num, regression_weights):
"""Runs the frcnn head that does both class and box prediction.
Args:
......@@ -279,7 +283,7 @@ class MaskRCNNModel(tf.keras.Model):
classes. It is padded with -1s to indicate the invalid classes.
training: `bool`, if model is training or being evaluated.
model_outputs: `dict`, used for storing outputs used for eval and losses.
layer_num: `int`, the current frcnn layer in the cascade.
cascade_num: `int`, the index of the current frcnn stage in the cascade,
used to select the ROI sampler and the detection head.
regression_weights: `list`, weights used for l1 loss in bounding box
regression.
......@@ -305,7 +309,7 @@ class MaskRCNNModel(tf.keras.Model):
if training and gt_boxes is not None:
rois = tf.stop_gradient(rois)
current_roi_sampler = self.roi_sampler[layer_num]
current_roi_sampler = self.roi_sampler[cascade_num]
rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
current_roi_sampler(rois, gt_boxes, gt_classes))
# Create bounding box training targets.
......@@ -317,10 +321,11 @@ class MaskRCNNModel(tf.keras.Model):
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]), tf.zeros_like(box_targets), box_targets)
model_outputs.update({
'class_targets_{}'.format(layer_num)
if layer_num else 'class_targets':
'class_targets_{}'.format(cascade_num)
if cascade_num else 'class_targets':
matched_gt_classes,
'box_targets_{}'.format(layer_num) if layer_num else 'box_targets':
'box_targets_{}'.format(cascade_num)
if cascade_num else 'box_targets':
box_targets,
})
......@@ -328,12 +333,14 @@ class MaskRCNNModel(tf.keras.Model):
roi_features = self.roi_aligner(features, rois)
# Run frcnn head to get class and bbox predictions.
class_outputs, box_outputs = self.detection_head(roi_features)
current_detection_head = self.detection_head[cascade_num]
class_outputs, box_outputs = current_detection_head(roi_features)
model_outputs.update({
'class_outputs_{}'.format(layer_num) if layer_num else 'class_outputs':
'class_outputs_{}'.format(cascade_num)
if cascade_num else 'class_outputs':
class_outputs,
'box_outputs_{}'.format(layer_num) if layer_num else 'box_outputs':
'box_outputs_{}'.format(cascade_num) if cascade_num else 'box_outputs':
box_outputs,
})
return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
......
......@@ -373,7 +373,7 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head)
detection_head=[detection_head])
if include_mask:
expect_checkpoint_items['mask_head'] = mask_head
self.assertAllEqual(expect_checkpoint_items, model.checkpoint_items)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册