Commit f4477a29 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339773662
Parent: 9139a7b9
@@ -9,7 +9,7 @@ task:
       type: 'dilated_resnet'
       dilated_resnet:
         model_id: 101
-        output_stride: 8
+        output_stride: 16
     norm_activation:
       activation: 'swish'
   losses:
...
@@ -9,7 +9,7 @@ task:
       type: 'dilated_resnet'
       dilated_resnet:
         model_id: 50
-        output_stride: 8
+        output_stride: 16
     norm_activation:
       activation: 'swish'
   losses:
...
-# Dilated ResNet-50 Pascal segmentation. 80.89 mean IOU.
+# Dilated ResNet-101 Pascal segmentation. 80.89 mean IOU.
 runtime:
   distribution_strategy: 'tpu'
   mixed_precision_dtype: 'float32'
...
+# Dilated ResNet-101 Pascal segmentation. 80.83 mean IOU with output stride of 16.
+runtime:
+  distribution_strategy: 'tpu'
+  mixed_precision_dtype: 'float32'
+task:
+  model:
+    backbone:
+      type: 'dilated_resnet'
+      dilated_resnet:
+        model_id: 101
+        output_stride: 16
+    head:
+      feature_fusion: 'deeplabv3plus'
+      low_level: 2
+      low_level_num_filters: 48
+  init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
+  init_checkpoint_modules: 'backbone'

+runtime:
+  distribution_strategy: 'tpu'
+  mixed_precision_dtype: 'float32'
+task:
+  model:
+    backbone:
+      type: 'dilated_resnet'
+      dilated_resnet:
+        model_id: 50
+        output_stride: 16
+    head:
+      feature_fusion: 'deeplabv3plus'
+      low_level: 2
+      low_level_num_filters: 48
+  init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
+  init_checkpoint_modules: 'backbone'
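
For orientation: these YAML files override the experiment configurations registered in the Python config module changed below, via the Model Garden's experiment factory. A minimal usage sketch follows; `exp_factory.get_exp_config` appears in this commit's own tests, while the `override` call and its exact nesting are assumptions based on how `hyperparams.Config` is typically used, not part of this commit.

```python
# Hypothetical sketch: fetch a registered experiment config and override
# nested fields the same way the YAML files above do.
from official.core import exp_factory

config = exp_factory.get_exp_config('seg_deeplabv3plus_pascal')
config.override({
    'task': {
        'model': {
            'backbone': {
                'dilated_resnet': {'model_id': 101, 'output_stride': 16}
            }
        }
    }
})
```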
@@ -15,8 +15,11 @@
 # ==============================================================================
 """Semantic segmentation configuration definition."""
 import os
-from typing import List, Union, Optional
+from typing import List, Optional, Union
 import dataclasses
+import numpy as np
 from official.core import exp_factory
 from official.modeling import hyperparams
 from official.modeling import optimization
@@ -48,6 +51,10 @@ class SegmentationHead(hyperparams.Config):
   num_convs: int = 2
   num_filters: int = 256
   upsample_factor: int = 1
+  feature_fusion: Optional[str] = None  # None, or deeplabv3plus
+  # deeplabv3plus feature fusion params
+  low_level: int = 2
+  low_level_num_filters: int = 48


 @dataclasses.dataclass
@@ -109,6 +116,9 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
   train_batch_size = 16
   eval_batch_size = 8
   steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
+  output_stride = 8
+  aspp_dilation_rates = [12, 24, 36]  # [6, 12, 18] if output_stride = 16
+  level = int(np.math.log2(output_stride))
   config = cfg.ExperimentConfig(
       task=SemanticSegmentationTask(
           model=SemanticSegmentationModel(
@@ -117,11 +127,99 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
               input_size=[512, 512, 3],
               backbone=backbones.Backbone(
                   type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                      model_id=50, output_stride=8)),
+                      model_id=50, output_stride=output_stride)),
               decoder=decoders.Decoder(
                   type='aspp', aspp=decoders.ASPP(
-                      level=3, dilation_rates=[12, 24, 36])),
-              head=SegmentationHead(level=3, num_convs=0),
+                      level=level, dilation_rates=aspp_dilation_rates)),
+              head=SegmentationHead(level=level, num_convs=0),
+              norm_activation=common.NormActivation(
+                  activation='swish',
+                  norm_momentum=0.9997,
+                  norm_epsilon=1e-3,
+                  use_sync_bn=True)),
+          losses=Losses(l2_weight_decay=1e-4),
+          train_data=DataConfig(
+              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
+              is_training=True,
+              global_batch_size=train_batch_size,
+              aug_scale_min=0.5,
+              aug_scale_max=2.0),
+          validation_data=DataConfig(
+              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
+              is_training=False,
+              global_batch_size=eval_batch_size,
+              resize_eval_groundtruth=False,
+              groundtruth_padded_size=[512, 512],
+              drop_remainder=False),
+          # resnet50
+          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400',
+          init_checkpoint_modules='backbone'),
+      trainer=cfg.TrainerConfig(
+          steps_per_loop=steps_per_epoch,
+          summary_interval=steps_per_epoch,
+          checkpoint_interval=steps_per_epoch,
+          train_steps=45 * steps_per_epoch,
+          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
+          validation_interval=steps_per_epoch,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9
+                  }
+              },
+              'learning_rate': {
+                  'type': 'polynomial',
+                  'polynomial': {
+                      'initial_learning_rate': 0.007,
+                      'decay_steps': 45 * steps_per_epoch,
+                      'end_learning_rate': 0.0,
+                      'power': 0.9
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 5 * steps_per_epoch,
+                      'warmup_learning_rate': 0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
+
+
+@exp_factory.register_config_factory('seg_deeplabv3plus_pascal')
+def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
+  """Image segmentation on pascal voc with resnet deeplabv3+."""
+  train_batch_size = 16
+  eval_batch_size = 8
+  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
+  output_stride = 16
+  aspp_dilation_rates = [6, 12, 18]  # [12, 24, 36] if output_stride = 8
+  level = int(np.math.log2(output_stride))
+  config = cfg.ExperimentConfig(
+      task=SemanticSegmentationTask(
+          model=SemanticSegmentationModel(
+              num_classes=21,
+              input_size=[512, 512, 3],
+              backbone=backbones.Backbone(
+                  type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
+                      model_id=50, output_stride=output_stride)),
+              decoder=decoders.Decoder(
+                  type='aspp',
+                  aspp=decoders.ASPP(
+                      level=level, dilation_rates=aspp_dilation_rates)),
+              head=SegmentationHead(
+                  level=level,
+                  num_convs=2,
+                  feature_fusion='deeplabv3plus',
+                  low_level=2,
+                  low_level_num_filters=48),
               norm_activation=common.NormActivation(
                   activation='swish',
                   norm_momentum=0.9997,
...
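
The `output_stride`/`level`/dilation-rate bookkeeping added above follows one rule: a backbone output stride of 2^L feeds feature level L, and halving the output stride doubles the ASPP dilation rates. A self-contained sketch of that mapping (using stdlib `math.log2`, which `np.math.log2` in the diff aliases; the helper name is illustrative):

```python
import math

def aspp_settings(output_stride):
  """Maps backbone output stride to (feature level, ASPP dilation rates)."""
  level = int(math.log2(output_stride))  # stride 8 -> level 3, 16 -> level 4
  dilation_rates = [12, 24, 36] if output_stride == 8 else [6, 12, 18]
  return level, dilation_rates

assert aspp_settings(8) == (3, [12, 24, 36])   # seg_deeplabv3_pascal
assert aspp_settings(16) == (4, [6, 12, 18])   # seg_deeplabv3plus_pascal
```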
@@ -27,7 +27,8 @@ from official.vision.beta.configs import semantic_segmentation as exp_cfg
 class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):

-  @parameterized.parameters(('seg_deeplabv3_pascal',),)
+  @parameterized.parameters(('seg_deeplabv3_pascal',),
+                            ('seg_deeplabv3plus_pascal',))
   def test_semantic_segmentation_configs(self, config_name):
     config = exp_factory.get_exp_config(config_name)
     self.assertIsInstance(config, cfg.ExperimentConfig)
...
@@ -263,6 +263,9 @@ def build_segmentation_model(
       num_convs=head_config.num_convs,
       num_filters=head_config.num_filters,
       upsample_factor=head_config.upsample_factor,
+      feature_fusion=head_config.feature_fusion,
+      low_level=head_config.low_level,
+      low_level_num_filters=head_config.low_level_num_filters,
       activation=norm_activation_config.activation,
       use_sync_bn=norm_activation_config.use_sync_bn,
       norm_momentum=norm_activation_config.norm_momentum,
...
@@ -14,7 +14,6 @@
 # ==============================================================================
 """Segmentation heads."""

-# Import libraries
 import tensorflow as tf

 from official.modeling import tf_utils
@@ -31,6 +30,9 @@ class SegmentationHead(tf.keras.layers.Layer):
                num_convs=2,
                num_filters=256,
                upsample_factor=1,
+               feature_fusion=None,
+               low_level=2,
+               low_level_num_filters=48,
                activation='relu',
                use_sync_bn=False,
                norm_momentum=0.99,
...@@ -50,6 +52,14 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -50,6 +52,14 @@ class SegmentationHead(tf.keras.layers.Layer):
Default is 256. Default is 256.
upsample_factor: `int` number to specify the upsampling factor to generate upsample_factor: `int` number to specify the upsampling factor to generate
finer mask. Default 1 means no upsampling is applied. finer mask. Default 1 means no upsampling is applied.
feature_fusion: One of `deeplabv3plus`, or None. If `deeplabv3plus`,
features from decoder_features[level] will be fused with
low level feature maps from backbone.
low_level: `int`, backbone level to be used for feature fusion. This arg
is used when feature_fusion is set to deeplabv3plus.
low_level_num_filters: `int`, reduced number of filters for the low
level features before fusing it with higher level features. This args is
only used when feature_fusion is set to deeplabv3plus.
activation: `string`, indicating which activation is used, e.g. 'relu', activation: `string`, indicating which activation is used, e.g. 'relu',
'swish', etc. 'swish', etc.
use_sync_bn: `bool`, whether to use synchronized batch normalization use_sync_bn: `bool`, whether to use synchronized batch normalization
@@ -63,12 +73,16 @@ class SegmentationHead(tf.keras.layers.Layer):
       **kwargs: other keyword arguments passed to Layer.
     """
     super(SegmentationHead, self).__init__(**kwargs)
     self._config_dict = {
         'num_classes': num_classes,
         'level': level,
         'num_convs': num_convs,
         'num_filters': num_filters,
         'upsample_factor': upsample_factor,
+        'feature_fusion': feature_fusion,
+        'low_level': low_level,
+        'low_level_num_filters': low_level_num_filters,
         'activation': activation,
         'use_sync_bn': use_sync_bn,
         'norm_momentum': norm_momentum,
@@ -101,6 +115,20 @@ class SegmentationHead(tf.keras.layers.Layer):
         'epsilon': self._config_dict['norm_epsilon'],
     }

+    if self._config_dict['feature_fusion'] == 'deeplabv3plus':
+      # Deeplabv3+ feature fusion layers.
+      self._dlv3p_conv = conv_op(
+          kernel_size=1,
+          padding='same',
+          bias_initializer=tf.zeros_initializer(),
+          kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+          kernel_regularizer=self._config_dict['kernel_regularizer'],
+          name='segmentation_head_deeplabv3p_fusion_conv',
+          filters=self._config_dict['low_level_num_filters'])
+      self._dlv3p_norm = bn_op(
+          name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs)
+
     # Segmentation head layers.
     self._convs = []
     self._norms = []
@@ -121,11 +149,15 @@ class SegmentationHead(tf.keras.layers.Layer):
     super(SegmentationHead, self).build(input_shape)

-  def call(self, features):
+  def call(self, backbone_output, decoder_output):
     """Forward pass of the segmentation head.

     Args:
-      features: a dict of tensors
+      backbone_output: a dict of tensors
+        - key: `str`, the level of the multilevel features.
+        - values: `Tensor`, the feature map tensors, whose shape is
+          [batch, height_l, width_l, channels].
+      decoder_output: a dict of tensors
         - key: `str`, the level of the multilevel features.
         - values: `Tensor`, the feature map tensors, whose shape is
           [batch, height_l, width_l, channels].
@@ -133,7 +165,20 @@ class SegmentationHead(tf.keras.layers.Layer):
       segmentation prediction mask: `Tensor`, the segmentation mask scores
         predicted from input feature.
     """
-    x = features[str(self._config_dict['level'])]
+    x = decoder_output[str(self._config_dict['level'])]
+
+    if self._config_dict['feature_fusion'] == 'deeplabv3plus':
+      # deeplabv3+ feature fusion
+      y = backbone_output[str(self._config_dict['low_level'])]
+      y = self._dlv3p_norm(self._dlv3p_conv(y))
+      y = self._activation(y)
+      x = tf.image.resize(
+          x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
+      x = tf.concat([x, y], axis=self._bn_axis)
+
     for conv, norm in zip(self._convs, self._norms):
       x = conv(x)
       x = norm(x)
...
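
Numerically, the fusion branch in `call()` behaves as in the following eager-mode sketch. It assumes NHWC layout (`_bn_axis == -1`), `level=4`, `low_level=2`, and a 512x512 input; the ad-hoc Conv2D/BatchNormalization layers stand in for `_dlv3p_conv`/`_dlv3p_norm`, and all shapes are illustrative.

```python
import tensorflow as tf

# Stand-ins for the head's two inputs: level 4 decoder output (stride 16)
# and level 2 backbone output (stride 4) for a 512x512 image.
decoder_output = {'4': tf.random.normal([2, 32, 32, 256])}
backbone_output = {'2': tf.random.normal([2, 128, 128, 256])}

x = decoder_output['4']
# 1x1 conv + norm + activation reduce the low-level features to 48 channels.
y = tf.keras.layers.Conv2D(48, 1, padding='same')(backbone_output['2'])
y = tf.nn.swish(tf.keras.layers.BatchNormalization()(y))
# Bilinearly upsample decoder features to the low-level size, then concat.
x = tf.image.resize(x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
x = tf.concat([x, y], axis=-1)
print(x.shape)  # (2, 128, 128, 304); the head convs then run at this size.
```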
@@ -30,14 +30,19 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
   )
   def test_forward(self, level):
     head = segmentation_heads.SegmentationHead(num_classes=10, level=level)
-    features = {
+    backbone_features = {
         '3': np.random.rand(2, 128, 128, 16),
         '4': np.random.rand(2, 64, 64, 16),
     }
-    logits = head(features)
+    decoder_features = {
+        '3': np.random.rand(2, 128, 128, 16),
+        '4': np.random.rand(2, 64, 64, 16),
+    }
+    logits = head(backbone_features, decoder_features)
     self.assertAllEqual(
         logits.numpy().shape,
-        [2, features[str(level)].shape[1], features[str(level)].shape[2], 10])
+        [2, decoder_features[str(level)].shape[1],
+         decoder_features[str(level)].shape[2], 10])

   def test_serialize_deserialize(self):
     head = segmentation_heads.SegmentationHead(num_classes=10, level=3)
...
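
Note the updated test still exercises only the no-fusion path. A hypothetical additional case (not part of this commit) that would cover the deeplabv3plus branch, reusing this file's imports and shape conventions:

```python
def test_forward_deeplabv3plus(self):
  head = segmentation_heads.SegmentationHead(
      num_classes=10, level=4, feature_fusion='deeplabv3plus', low_level=3)
  backbone_features = {'3': np.random.rand(2, 128, 128, 16)}
  decoder_features = {'4': np.random.rand(2, 64, 64, 16)}
  logits = head(backbone_features, decoder_features)
  # After fusion, the output sits at the low-level (128x128) resolution.
  self.assertAllEqual(logits.numpy().shape, [2, 128, 128, 10])
```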
@@ -26,7 +26,11 @@ class SegmentationModel(tf.keras.Model):

   Input images are passed through backbone first. Decoder network is then
   applied, and finally, segmentation head is applied on the output of the
-  decoder network. Layers such as ASPP should be part of decoder.
+  decoder network. Layers such as ASPP should be part of the decoder. Any
+  feature fusion is done as part of the segmentation head (i.e., deeplabv3+
+  feature fusion is not part of the decoder; instead it is part of the
+  segmentation head). This way, different feature fusion techniques can be
+  combined with different backbones and decoders.
   """

   def __init__(self,
@@ -53,11 +57,14 @@ class SegmentationModel(tf.keras.Model):
     self.head = head

   def call(self, inputs, training=None):
-    features = self.backbone(inputs)
+    backbone_features = self.backbone(inputs)

     if self.decoder:
-      features = self.decoder(features)
+      decoder_features = self.decoder(backbone_features)
+    else:
+      decoder_features = backbone_features

-    return self.head(features)
+    return self.head(backbone_features, decoder_features)

   @property
   def checkpoint_items(self):
...
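
Condensed, the new forward contract looks like the sketch below (function and argument names are illustrative, not part of the commit):

```python
def segmentation_forward(inputs, backbone, decoder, head):
  """Sketch of SegmentationModel.call(): the head sees both feature dicts."""
  backbone_features = backbone(inputs)  # dict: str(level) -> Tensor
  decoder_features = (
      decoder(backbone_features) if decoder else backbone_features)
  # deeplabv3plus fusion can reach past the decoder into backbone features;
  # heads without fusion simply read decoder_features[level] and ignore the
  # first argument.
  return head(backbone_features, decoder_features)
```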
@@ -30,10 +30,10 @@ from official.vision.beta.modeling import factory
 @task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
 class SemanticSegmentationTask(base_task.Task):
-  """A task for semantic classification."""
+  """A task for semantic segmentation."""

   def build_model(self):
-    """Builds classification model."""
+    """Builds segmentation model."""
     input_specs = tf.keras.layers.InputSpec(
         shape=[None] + self.task_config.model.input_size)
@@ -105,7 +105,7 @@ class SemanticSegmentationTask(base_task.Task):
     return dataset

   def build_losses(self, labels, model_outputs, aux_losses=None):
-    """Sparse categorical cross entropy loss.
+    """Segmentation loss.

     Args:
       labels: labels.
...