diff --git a/official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_deeplab_tpu.yaml b/official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_deeplab_tpu.yaml index 638b48cf1d79eae856efae33994fb4c5b85fc9b5..b705a102c9518b4aa44c00f59166a7f08a2053c3 100644 --- a/official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_deeplab_tpu.yaml +++ b/official/vision/beta/configs/experiments/image_classification/imagenet_resnet101_deeplab_tpu.yaml @@ -9,7 +9,7 @@ task: type: 'dilated_resnet' dilated_resnet: model_id: 101 - output_stride: 8 + output_stride: 16 norm_activation: activation: 'swish' losses: diff --git a/official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_deeplab_tpu.yaml b/official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_deeplab_tpu.yaml index f05956b59c705e68ff3246f12575ccc0bceefb2a..11bdafbc35d4c6f63625b5990de0167a78a7e6b0 100644 --- a/official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_deeplab_tpu.yaml +++ b/official/vision/beta/configs/experiments/image_classification/imagenet_resnet50_deeplab_tpu.yaml @@ -9,7 +9,7 @@ task: type: 'dilated_resnet' dilated_resnet: model_id: 50 - output_stride: 8 + output_stride: 16 norm_activation: activation: 'swish' losses: diff --git a/official/vision/beta/configs/experiments/semantic_segmentation/deeplab_resnet101_pascal_tpu.yaml b/official/vision/beta/configs/experiments/semantic_segmentation/deeplab_resnet101_pascal_tpu.yaml index 5681a48efe03e3b6dd1959adc88aaaf8f227ae67..38d58f87d22c04c7647a980c872072d67364129e 100644 --- a/official/vision/beta/configs/experiments/semantic_segmentation/deeplab_resnet101_pascal_tpu.yaml +++ b/official/vision/beta/configs/experiments/semantic_segmentation/deeplab_resnet101_pascal_tpu.yaml @@ -1,4 +1,4 @@ -# Dilated ResNet-50 Pascal segmentation. 80.89 mean IOU. +# Dilated ResNet-101 Pascal segmentation. 80.89 mean IOU. 
runtime: distribution_strategy: 'tpu' mixed_precision_dtype: 'float32' diff --git a/official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet101_pascal_tpu.yaml b/official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet101_pascal_tpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e72dae16f6319edf1667150bf6657daf9b3b8da --- /dev/null +++ b/official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet101_pascal_tpu.yaml @@ -0,0 +1,17 @@ +# Dilated ResNet-101 Pascal segmentation. 80.83 mean IOU with output stride of 16. +runtime: + distribution_strategy: 'tpu' + mixed_precision_dtype: 'float32' +task: + model: + backbone: + type: 'dilated_resnet' + dilated_resnet: + model_id: 101 + output_stride: 16 + head: + feature_fusion: 'deeplabv3plus' + low_level: 2 + low_level_num_filters: 48 + init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400' + init_checkpoint_modules: 'backbone' diff --git a/official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet50_pascal_tpu.yaml b/official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet50_pascal_tpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2fefa86df30ba2deef19cef01a48d2e27a21c341 --- /dev/null +++ b/official/vision/beta/configs/experiments/semantic_segmentation/deeplabv3plus_resnet50_pascal_tpu.yaml @@ -0,0 +1,16 @@ +runtime: + distribution_strategy: 'tpu' + mixed_precision_dtype: 'float32' +task: + model: + backbone: + type: 'dilated_resnet' + dilated_resnet: + model_id: 50 + output_stride: 16 + head: + feature_fusion: 'deeplabv3plus' + low_level: 2 + low_level_num_filters: 48 + init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400' + init_checkpoint_modules: 'backbone' diff --git a/official/vision/beta/configs/semantic_segmentation.py 
b/official/vision/beta/configs/semantic_segmentation.py index 5d86a93a09c7b2572698db494dd41ec20f6dd507..fbf4992248cff75427ffa7a80e0ad7ad07013fb8 100644 --- a/official/vision/beta/configs/semantic_segmentation.py +++ b/official/vision/beta/configs/semantic_segmentation.py @@ -15,8 +15,11 @@ # ============================================================================== """Semantic segmentation configuration definition.""" import os -from typing import List, Union, Optional +from typing import List, Optional, Union + import dataclasses +import numpy as np + from official.core import exp_factory from official.modeling import hyperparams from official.modeling import optimization @@ -48,6 +51,10 @@ class SegmentationHead(hyperparams.Config): num_convs: int = 2 num_filters: int = 256 upsample_factor: int = 1 + feature_fusion: Optional[str] = None # None, or deeplabv3plus + # deeplabv3plus feature fusion params + low_level: int = 2 + low_level_num_filters: int = 48 @dataclasses.dataclass @@ -109,6 +116,9 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig: train_batch_size = 16 eval_batch_size = 8 steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size + output_stride = 8 + aspp_dilation_rates = [12, 24, 36] # [6, 12, 18] if output_stride = 16 + level = int(np.math.log2(output_stride)) config = cfg.ExperimentConfig( task=SemanticSegmentationTask( model=SemanticSegmentationModel( @@ -117,11 +127,99 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig: input_size=[512, 512, 3], backbone=backbones.Backbone( type='dilated_resnet', dilated_resnet=backbones.DilatedResNet( - model_id=50, output_stride=8)), + model_id=50, output_stride=output_stride)), decoder=decoders.Decoder( type='aspp', aspp=decoders.ASPP( - level=3, dilation_rates=[12, 24, 36])), - head=SegmentationHead(level=3, num_convs=0), + level=level, dilation_rates=aspp_dilation_rates)), + head=SegmentationHead(level=level, num_convs=0), + norm_activation=common.NormActivation( + activation='swish', + 
norm_momentum=0.9997, + norm_epsilon=1e-3, + use_sync_bn=True)), + losses=Losses(l2_weight_decay=1e-4), + train_data=DataConfig( + input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'), + is_training=True, + global_batch_size=train_batch_size, + aug_scale_min=0.5, + aug_scale_max=2.0), + validation_data=DataConfig( + input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'), + is_training=False, + global_batch_size=eval_batch_size, + resize_eval_groundtruth=False, + groundtruth_padded_size=[512, 512], + drop_remainder=False), + # resnet50 + init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400', + init_checkpoint_modules='backbone'), + trainer=cfg.TrainerConfig( + steps_per_loop=steps_per_epoch, + summary_interval=steps_per_epoch, + checkpoint_interval=steps_per_epoch, + train_steps=45 * steps_per_epoch, + validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size, + validation_interval=steps_per_epoch, + optimizer_config=optimization.OptimizationConfig({ + 'optimizer': { + 'type': 'sgd', + 'sgd': { + 'momentum': 0.9 + } + }, + 'learning_rate': { + 'type': 'polynomial', + 'polynomial': { + 'initial_learning_rate': 0.007, + 'decay_steps': 45 * steps_per_epoch, + 'end_learning_rate': 0.0, + 'power': 0.9 + } + }, + 'warmup': { + 'type': 'linear', + 'linear': { + 'warmup_steps': 5 * steps_per_epoch, + 'warmup_learning_rate': 0 + } + } + })), + restrictions=[ + 'task.train_data.is_training != None', + 'task.validation_data.is_training != None' + ]) + + return config + + +@exp_factory.register_config_factory('seg_deeplabv3plus_pascal') +def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig: + """Image segmentation on Pascal VOC with resnet deeplabv3+.""" + train_batch_size = 16 + eval_batch_size = 8 + steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size + output_stride = 16 + aspp_dilation_rates = [6, 12, 18] # [12, 24, 36] if output_stride = 8 + level = int(np.math.log2(output_stride)) + config = 
cfg.ExperimentConfig( + task=SemanticSegmentationTask( + model=SemanticSegmentationModel( + num_classes=21, + input_size=[512, 512, 3], + backbone=backbones.Backbone( + type='dilated_resnet', dilated_resnet=backbones.DilatedResNet( + model_id=50, output_stride=output_stride)), + decoder=decoders.Decoder( + type='aspp', + aspp=decoders.ASPP( + level=level, dilation_rates=aspp_dilation_rates)), + head=SegmentationHead( + level=level, + num_convs=2, + feature_fusion='deeplabv3plus', + low_level=2, + low_level_num_filters=48), norm_activation=common.NormActivation( activation='swish', norm_momentum=0.9997, diff --git a/official/vision/beta/configs/semantic_segmentation_test.py b/official/vision/beta/configs/semantic_segmentation_test.py index e3f845239f853dfff2431b7755a5bb8a17c04ad0..0cb5a7c815eb1f8b5cf625abc9e8a1e267de43ab 100644 --- a/official/vision/beta/configs/semantic_segmentation_test.py +++ b/official/vision/beta/configs/semantic_segmentation_test.py @@ -27,7 +27,8 @@ from official.vision.beta.configs import semantic_segmentation as exp_cfg class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase): - @parameterized.parameters(('seg_deeplabv3_pascal',),) + @parameterized.parameters(('seg_deeplabv3_pascal',), + ('seg_deeplabv3plus_pascal',)) def test_semantic_segmentation_configs(self, config_name): config = exp_factory.get_exp_config(config_name) self.assertIsInstance(config, cfg.ExperimentConfig) diff --git a/official/vision/beta/modeling/factory.py b/official/vision/beta/modeling/factory.py index 5dfba59152d253c035c5ff6ea6b0f97dc7e2be90..e8335cac9dbecea019dccfbbc93b99516a4b18dd 100644 --- a/official/vision/beta/modeling/factory.py +++ b/official/vision/beta/modeling/factory.py @@ -263,6 +263,9 @@ def build_segmentation_model( num_convs=head_config.num_convs, num_filters=head_config.num_filters, upsample_factor=head_config.upsample_factor, + feature_fusion=head_config.feature_fusion, + low_level=head_config.low_level, + 
low_level_num_filters=head_config.low_level_num_filters, activation=norm_activation_config.activation, use_sync_bn=norm_activation_config.use_sync_bn, norm_momentum=norm_activation_config.norm_momentum, diff --git a/official/vision/beta/modeling/heads/segmentation_heads.py b/official/vision/beta/modeling/heads/segmentation_heads.py index afa96389f3c06297810b1dd8617517f7d41f27f2..d3f10bdf6cecb0cd50519b353deb4a706ee4b813 100644 --- a/official/vision/beta/modeling/heads/segmentation_heads.py +++ b/official/vision/beta/modeling/heads/segmentation_heads.py @@ -14,7 +14,6 @@ # ============================================================================== """Segmentation heads.""" -# Import libraries import tensorflow as tf from official.modeling import tf_utils @@ -31,6 +30,9 @@ class SegmentationHead(tf.keras.layers.Layer): num_convs=2, num_filters=256, upsample_factor=1, + feature_fusion=None, + low_level=2, + low_level_num_filters=48, activation='relu', use_sync_bn=False, norm_momentum=0.99, @@ -50,6 +52,14 @@ class SegmentationHead(tf.keras.layers.Layer): Default is 256. upsample_factor: `int` number to specify the upsampling factor to generate finer mask. Default 1 means no upsampling is applied. + feature_fusion: One of `deeplabv3plus`, or None. If `deeplabv3plus`, + features from decoder_features[level] will be fused with + low level feature maps from backbone. + low_level: `int`, backbone level to be used for feature fusion. This arg + is used when feature_fusion is set to deeplabv3plus. + low_level_num_filters: `int`, reduced number of filters for the low + level features before fusing it with higher level features. This arg is + only used when feature_fusion is set to deeplabv3plus. activation: `string`, indicating which activation is used, e.g. 'relu', 'swish', etc. use_sync_bn: `bool`, whether to use synchronized batch normalization @@ -63,12 +73,16 @@ class SegmentationHead(tf.keras.layers.Layer): **kwargs: other keyword arguments passed to Layer. 
""" super(SegmentationHead, self).__init__(**kwargs) + self._config_dict = { 'num_classes': num_classes, 'level': level, 'num_convs': num_convs, 'num_filters': num_filters, 'upsample_factor': upsample_factor, + 'feature_fusion': feature_fusion, + 'low_level': low_level, + 'low_level_num_filters': low_level_num_filters, 'activation': activation, 'use_sync_bn': use_sync_bn, 'norm_momentum': norm_momentum, @@ -101,6 +115,20 @@ class SegmentationHead(tf.keras.layers.Layer): 'epsilon': self._config_dict['norm_epsilon'], } + if self._config_dict['feature_fusion'] == 'deeplabv3plus': + # Deeplabv3+ feature fusion layers. + self._dlv3p_conv = conv_op( + kernel_size=1, + padding='same', + bias_initializer=tf.zeros_initializer(), + kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), + kernel_regularizer=self._config_dict['kernel_regularizer'], + name='segmentation_head_deeplabv3p_fusion_conv', + filters=self._config_dict['low_level_num_filters']) + + self._dlv3p_norm = bn_op( + name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs) + # Segmentation head layers. self._convs = [] self._norms = [] @@ -121,11 +149,15 @@ class SegmentationHead(tf.keras.layers.Layer): super(SegmentationHead, self).build(input_shape) - def call(self, features): + def call(self, backbone_output, decoder_output): """Forward pass of the segmentation head. Args: - features: a dict of tensors + backbone_output: a dict of tensors + - key: `str`, the level of the multilevel features. + - values: `Tensor`, the feature map tensors, whose shape is + [batch, height_l, width_l, channels]. + decoder_output: a dict of tensors - key: `str`, the level of the multilevel features. - values: `Tensor`, the feature map tensors, whose shape is [batch, height_l, width_l, channels]. @@ -133,7 +165,20 @@ class SegmentationHead(tf.keras.layers.Layer): segmentation prediction mask: `Tensor`, the segmentation mask scores predicted from input feature. 
""" - x = features[str(self._config_dict['level'])] + + x = decoder_output[str(self._config_dict['level'])] + + if self._config_dict['feature_fusion'] == 'deeplabv3plus': + # deeplabv3+ feature fusion + y = backbone_output[str( + self._config_dict['low_level'])] + y = self._dlv3p_norm(self._dlv3p_conv(y)) + y = self._activation(y) + + x = tf.image.resize( + x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR) + x = tf.concat([x, y], axis=self._bn_axis) + for conv, norm in zip(self._convs, self._norms): x = conv(x) x = norm(x) diff --git a/official/vision/beta/modeling/heads/segmentation_heads_test.py b/official/vision/beta/modeling/heads/segmentation_heads_test.py index 922dc2e32be6f389b8580ceb2da24fff44590bd2..31038c53c8c1fdd3bfdc5e4da2a7327420d9feff 100644 --- a/official/vision/beta/modeling/heads/segmentation_heads_test.py +++ b/official/vision/beta/modeling/heads/segmentation_heads_test.py @@ -30,14 +30,19 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase): ) def test_forward(self, level): head = segmentation_heads.SegmentationHead(num_classes=10, level=level) - features = { + backbone_features = { '3': np.random.rand(2, 128, 128, 16), '4': np.random.rand(2, 64, 64, 16), } - logits = head(features) + decoder_features = { + '3': np.random.rand(2, 128, 128, 16), + '4': np.random.rand(2, 64, 64, 16), + } + logits = head(backbone_features, decoder_features) self.assertAllEqual( logits.numpy().shape, - [2, features[str(level)].shape[1], features[str(level)].shape[2], 10]) + [2, decoder_features[str(level)].shape[1], + decoder_features[str(level)].shape[2], 10]) def test_serialize_deserialize(self): head = segmentation_heads.SegmentationHead(num_classes=10, level=3) diff --git a/official/vision/beta/modeling/segmentation_model.py b/official/vision/beta/modeling/segmentation_model.py index 0588748c20c75c09f238c8837936a05835b75f99..7464c1188e5f051eeaf56be16f8568143cb23397 100644 --- a/official/vision/beta/modeling/segmentation_model.py +++ 
b/official/vision/beta/modeling/segmentation_model.py @@ -26,7 +26,11 @@ class SegmentationModel(tf.keras.Model): Input images are passed through backbone first. Decoder network is then applied, and finally, segmentation head is applied on the output of the - decoder network. Layers such as ASPP should be part of decoder. + decoder network. Layers such as ASPP should be part of decoder. Any feature + fusion is done as part of the segmentation head (i.e. deeplabv3+ feature + fusion is not part of the decoder, instead it is part of the segmentation + head). This way, different feature fusion techniques can be combined with + different backbones, and decoders. """ def __init__(self, @@ -53,11 +57,14 @@ class SegmentationModel(tf.keras.Model): self.head = head def call(self, inputs, training=None): - features = self.backbone(inputs) + backbone_features = self.backbone(inputs) if self.decoder: - features = self.decoder(features) - return self.head(features) + decoder_features = self.decoder(backbone_features) + else: + decoder_features = backbone_features + + return self.head(backbone_features, decoder_features) @property def checkpoint_items(self): diff --git a/official/vision/beta/tasks/semantic_segmentation.py b/official/vision/beta/tasks/semantic_segmentation.py index e68d9f761b2e8dbb23aa5eae5e25be9d33992ec0..0ab4034adcf82a9b539cec035b8f22385c286c47 100644 --- a/official/vision/beta/tasks/semantic_segmentation.py +++ b/official/vision/beta/tasks/semantic_segmentation.py @@ -30,10 +30,10 @@ from official.vision.beta.modeling import factory @task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask) class SemanticSegmentationTask(base_task.Task): - """A task for semantic classification.""" + """A task for semantic segmentation.""" def build_model(self): - """Builds classification model.""" + """Builds segmentation model.""" input_specs = tf.keras.layers.InputSpec( shape=[None] + self.task_config.model.input_size) @@ -105,7 +105,7 @@ class 
SemanticSegmentationTask(base_task.Task): return dataset def build_losses(self, labels, model_outputs, aux_losses=None): - """Sparse categorical cross entropy loss. + """Segmentation loss. Args: labels: labels.