提交 f4477a29 编写于 作者: A Abdullah Rashwan 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339773662
上级 9139a7b9
......@@ -9,7 +9,7 @@ task:
type: 'dilated_resnet'
dilated_resnet:
model_id: 101
output_stride: 8
output_stride: 16
norm_activation:
activation: 'swish'
losses:
......
......@@ -9,7 +9,7 @@ task:
type: 'dilated_resnet'
dilated_resnet:
model_id: 50
output_stride: 8
output_stride: 16
norm_activation:
activation: 'swish'
losses:
......
# Dilated ResNet-50 Pascal segmentation. 80.89 mean IOU.
# Dilated ResNet-101 Pascal segmentation. 80.89 mean IOU.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
......
# Dilated ResNet-101 Pascal segmentation. 80.83 mean IOU with output stride of 16.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 101
output_stride: 16
head:
feature_fusion: 'deeplabv3plus'
low_level: 2
low_level_num_filters: 48
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 50
output_stride: 16
head:
feature_fusion: 'deeplabv3plus'
low_level: 2
low_level_num_filters: 48
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
......@@ -15,8 +15,11 @@
# ==============================================================================
"""Semantic segmentation configuration definition."""
import os
from typing import List, Union, Optional
from typing import List, Optional, Union
import dataclasses
import numpy as np
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
......@@ -48,6 +51,10 @@ class SegmentationHead(hyperparams.Config):
num_convs: int = 2
num_filters: int = 256
upsample_factor: int = 1
feature_fusion: Optional[str] = None # None, or deeplabv3plus
# deeplabv3plus feature fusion params
low_level: int = 2
low_level_num_filters: int = 48
@dataclasses.dataclass
......@@ -109,6 +116,9 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
train_batch_size = 16
eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
output_stride = 8
aspp_dilation_rates = [12, 24, 36] # [6, 12, 18] if output_stride = 16
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
task=SemanticSegmentationTask(
model=SemanticSegmentationModel(
......@@ -117,11 +127,99 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
input_size=[512, 512, 3],
backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50, output_stride=8)),
model_id=50, output_stride=output_stride)),
decoder=decoders.Decoder(
type='aspp', aspp=decoders.ASPP(
level=3, dilation_rates=[12, 24, 36])),
head=SegmentationHead(level=3, num_convs=0),
level=level, dilation_rates=aspp_dilation_rates)),
head=SegmentationHead(level=level, num_convs=0),
norm_activation=common.NormActivation(
activation='swish',
norm_momentum=0.9997,
norm_epsilon=1e-3,
use_sync_bn=True)),
losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
is_training=True,
global_batch_size=train_batch_size,
aug_scale_min=0.5,
aug_scale_max=2.0),
validation_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
resize_eval_groundtruth=False,
groundtruth_padded_size=[512, 512],
drop_remainder=False),
# resnet50
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400',
init_checkpoint_modules='backbone'),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=45 * steps_per_epoch,
validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.007,
'decay_steps': 45 * steps_per_epoch,
'end_learning_rate': 0.0,
'power': 0.9
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('seg_deeplabv3plus_pascal')
def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
"""Image segmentation on imagenet with resnet deeplabv3+."""
train_batch_size = 16
eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
output_stride = 16
aspp_dilation_rates = [6, 12, 18] # [12, 24, 36] if output_stride = 8
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
task=SemanticSegmentationTask(
model=SemanticSegmentationModel(
num_classes=21,
input_size=[512, 512, 3],
backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50, output_stride=output_stride)),
decoder=decoders.Decoder(
type='aspp',
aspp=decoders.ASPP(
level=level, dilation_rates=aspp_dilation_rates)),
head=SegmentationHead(
level=level,
num_convs=2,
feature_fusion='deeplabv3plus',
low_level=2,
low_level_num_filters=48),
norm_activation=common.NormActivation(
activation='swish',
norm_momentum=0.9997,
......
......@@ -27,7 +27,8 @@ from official.vision.beta.configs import semantic_segmentation as exp_cfg
class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('seg_deeplabv3_pascal',),)
@parameterized.parameters(('seg_deeplabv3_pascal',),
('seg_deeplabv3plus_pascal',))
def test_semantic_segmentation_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
......
......@@ -263,6 +263,9 @@ def build_segmentation_model(
num_convs=head_config.num_convs,
num_filters=head_config.num_filters,
upsample_factor=head_config.upsample_factor,
feature_fusion=head_config.feature_fusion,
low_level=head_config.low_level,
low_level_num_filters=head_config.low_level_num_filters,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
......
......@@ -14,7 +14,6 @@
# ==============================================================================
"""Segmentation heads."""
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
......@@ -31,6 +30,9 @@ class SegmentationHead(tf.keras.layers.Layer):
num_convs=2,
num_filters=256,
upsample_factor=1,
feature_fusion=None,
low_level=2,
low_level_num_filters=48,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
......@@ -50,6 +52,14 @@ class SegmentationHead(tf.keras.layers.Layer):
Default is 256.
upsample_factor: `int` number to specify the upsampling factor to generate
finer mask. Default 1 means no upsampling is applied.
feature_fusion: One of `deeplabv3plus`, or None. If `deeplabv3plus`,
features from decoder_features[level] will be fused with
low level feature maps from backbone.
low_level: `int`, backbone level to be used for feature fusion. This arg
is used when feature_fusion is set to deeplabv3plus.
low_level_num_filters: `int`, reduced number of filters for the low
level features before fusing it with higher level features. This args is
only used when feature_fusion is set to deeplabv3plus.
activation: `string`, indicating which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: `bool`, whether to use synchronized batch normalization
......@@ -63,12 +73,16 @@ class SegmentationHead(tf.keras.layers.Layer):
**kwargs: other keyword arguments passed to Layer.
"""
super(SegmentationHead, self).__init__(**kwargs)
self._config_dict = {
'num_classes': num_classes,
'level': level,
'num_convs': num_convs,
'num_filters': num_filters,
'upsample_factor': upsample_factor,
'feature_fusion': feature_fusion,
'low_level': low_level,
'low_level_num_filters': low_level_num_filters,
'activation': activation,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
......@@ -101,6 +115,20 @@ class SegmentationHead(tf.keras.layers.Layer):
'epsilon': self._config_dict['norm_epsilon'],
}
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
# Deeplabv3+ feature fusion layers.
self._dlv3p_conv = conv_op(
kernel_size=1,
padding='same',
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'],
name='segmentation_head_deeplabv3p_fusion_conv',
filters=self._config_dict['low_level_num_filters'])
self._dlv3p_norm = bn_op(
name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs)
# Segmentation head layers.
self._convs = []
self._norms = []
......@@ -121,11 +149,15 @@ class SegmentationHead(tf.keras.layers.Layer):
super(SegmentationHead, self).build(input_shape)
def call(self, features):
def call(self, backbone_output, decoder_output):
"""Forward pass of the segmentation head.
Args:
features: a dict of tensors
backbone_output: a dict of tensors
- key: `str`, the level of the multilevel features.
- values: `Tensor`, the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
decoder_output: a dict of tensors
- key: `str`, the level of the multilevel features.
- values: `Tensor`, the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
......@@ -133,7 +165,20 @@ class SegmentationHead(tf.keras.layers.Layer):
segmentation prediction mask: `Tensor`, the segmentation mask scores
predicted from input feature.
"""
x = features[str(self._config_dict['level'])]
x = decoder_output[str(self._config_dict['level'])]
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
# deeplabv3+ feature fusion
y = backbone_output[str(
self._config_dict['low_level'])]
y = self._dlv3p_norm(self._dlv3p_conv(y))
y = self._activation(y)
x = tf.image.resize(
x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
x = tf.concat([x, y], axis=self._bn_axis)
for conv, norm in zip(self._convs, self._norms):
x = conv(x)
x = norm(x)
......
......@@ -30,14 +30,19 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
)
def test_forward(self, level):
head = segmentation_heads.SegmentationHead(num_classes=10, level=level)
features = {
backbone_features = {
'3': np.random.rand(2, 128, 128, 16),
'4': np.random.rand(2, 64, 64, 16),
}
logits = head(features)
decoder_features = {
'3': np.random.rand(2, 128, 128, 16),
'4': np.random.rand(2, 64, 64, 16),
}
logits = head(backbone_features, decoder_features)
self.assertAllEqual(
logits.numpy().shape,
[2, features[str(level)].shape[1], features[str(level)].shape[2], 10])
[2, decoder_features[str(level)].shape[1],
decoder_features[str(level)].shape[2], 10])
def test_serialize_deserialize(self):
head = segmentation_heads.SegmentationHead(num_classes=10, level=3)
......
......@@ -26,7 +26,11 @@ class SegmentationModel(tf.keras.Model):
Input images are passed through backbone first. Decoder network is then
applied, and finally, segmentation head is applied on the output of the
decoder network. Layers such as ASPP should be part of decoder.
decoder network. Layers such as ASPP should be part of decoder. Any feature
fusion is done as part of the segmentation head (i.e. deeplabv3+ feature
fusion is not part of the decoder, instead it is part of the segmentation
head). This way, different feature fusion techniques can be combined with
different backbones, and decoders.
"""
def __init__(self,
......@@ -53,11 +57,14 @@ class SegmentationModel(tf.keras.Model):
self.head = head
def call(self, inputs, training=None):
features = self.backbone(inputs)
backbone_features = self.backbone(inputs)
if self.decoder:
features = self.decoder(features)
return self.head(features)
decoder_features = self.decoder(backbone_features)
else:
decoder_features = backbone_features
return self.head(backbone_features, decoder_features)
@property
def checkpoint_items(self):
......
......@@ -30,10 +30,10 @@ from official.vision.beta.modeling import factory
@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class SemanticSegmentationTask(base_task.Task):
"""A task for semantic classification."""
"""A task for semantic segmentation."""
def build_model(self):
"""Builds classification model."""
"""Builds segmentation model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size)
......@@ -105,7 +105,7 @@ class SemanticSegmentationTask(base_task.Task):
return dataset
def build_losses(self, labels, model_outputs, aux_losses=None):
"""Sparse categorical cross entropy loss.
"""Segmentation loss.
Args:
labels: labels.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册