diff --git a/official/vision/modeling/backbones/efficientnet.py b/official/vision/modeling/backbones/efficientnet.py index bece521d3acafe370eb29a9e711d3f4fc15bd241..2dc7ba34c7bf230668a5e57edd62adb4a06aaaed 100644 --- a/official/vision/modeling/backbones/efficientnet.py +++ b/official/vision/modeling/backbones/efficientnet.py @@ -151,10 +151,7 @@ class EfficientNet(tf.keras.Model): self._norm_epsilon = norm_epsilon self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer - if use_sync_bn: - self._norm = layers.experimental.SyncBatchNormalization - else: - self._norm = layers.BatchNormalization + self._norm = layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': bn_axis = -1 @@ -178,7 +175,10 @@ class EfficientNet(tf.keras.Model): bias_regularizer=self._bias_regularizer)( inputs) x = self._norm( - axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( + axis=bn_axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn)( x) x = tf_utils.get_activation(activation)(x) @@ -210,7 +210,10 @@ class EfficientNet(tf.keras.Model): bias_regularizer=self._bias_regularizer)( x) x = self._norm( - axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( + axis=bn_axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn)( x) endpoints[str(endpoint_level)] = tf_utils.get_activation(activation)(x) diff --git a/official/vision/modeling/backbones/mobilenet.py b/official/vision/modeling/backbones/mobilenet.py index 0d695317c514048e42167437108711ee5902e55e..279eb917966f8c1f79d2899f799e3dee7775e565 100644 --- a/official/vision/modeling/backbones/mobilenet.py +++ b/official/vision/modeling/backbones/mobilenet.py @@ -91,15 +91,12 @@ class Conv2DBNBlock(tf.keras.layers.Layer): self._use_sync_bn = use_sync_bn self._norm_momentum = norm_momentum self._norm_epsilon = norm_epsilon + self._norm = tf.keras.layers.BatchNormalization if use_explicit_padding and kernel_size > 1: self._padding = 'valid' else: self._padding = 'same' - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 else: @@ -141,7 +138,8 @@ class Conv2DBNBlock(tf.keras.layers.Layer): self._norm0 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn) self._activation_layer = tf_utils.get_activation( self._activation, use_keras_layer=True) diff --git a/official/vision/modeling/backbones/resnet.py b/official/vision/modeling/backbones/resnet.py index 0a2b07971522001243b5cef7f0756762955a4621..197471be2c40cbe178ea67dd40d566458a4fdf25 100644 --- a/official/vision/modeling/backbones/resnet.py +++ b/official/vision/modeling/backbones/resnet.py @@ -179,10 +179,7 @@ class ResNet(tf.keras.Model): self._activation = activation self._norm_momentum = norm_momentum self._norm_epsilon = norm_epsilon - if use_sync_bn: - self._norm = layers.experimental.SyncBatchNormalization - else: - self._norm = layers.BatchNormalization + self._norm = layers.BatchNormalization self._kernel_initializer = kernel_initializer self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer @@ -238,6 +235,7 @@ class ResNet(tf.keras.Model): momentum=self._norm_momentum, epsilon=self._norm_epsilon, trainable=self._bn_trainable, + synchronized=self._use_sync_bn, )(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) elif self._stem_type == 'v1': @@ -256,6 +254,7 @@ class ResNet(tf.keras.Model): momentum=self._norm_momentum, epsilon=self._norm_epsilon, trainable=self._bn_trainable, + synchronized=self._use_sync_bn, )(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = layers.Conv2D( @@ -273,6 +272,7 @@ class ResNet(tf.keras.Model): momentum=self._norm_momentum, epsilon=self._norm_epsilon, trainable=self._bn_trainable, + synchronized=self._use_sync_bn, )(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = layers.Conv2D( @@ -290,6 +290,7 @@ class ResNet(tf.keras.Model): momentum=self._norm_momentum, epsilon=self._norm_epsilon, trainable=self._bn_trainable, + synchronized=self._use_sync_bn, )(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) else: @@ -311,6 +312,7 @@ class ResNet(tf.keras.Model): momentum=self._norm_momentum, epsilon=self._norm_epsilon, trainable=self._bn_trainable, + synchronized=self._use_sync_bn, )(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) else: diff --git a/official/vision/modeling/backbones/resnet_3d.py b/official/vision/modeling/backbones/resnet_3d.py index 2b323c73e557287ced996633acb608fa9e5348e6..0d51f3ca20e17e765ce1fc388ae326ce7d5046dc 100644 --- a/official/vision/modeling/backbones/resnet_3d.py +++ b/official/vision/modeling/backbones/resnet_3d.py @@ -145,10 +145,7 @@ class ResNet3D(tf.keras.Model): self._activation = activation self._norm_momentum = norm_momentum self._norm_epsilon = norm_epsilon - if use_sync_bn: - self._norm = layers.experimental.SyncBatchNormalization - else: - self._norm = layers.BatchNormalization + self._norm = layers.BatchNormalization self._kernel_initializer = kernel_initializer self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer @@ -232,7 +229,8 @@ class ResNet3D(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)(x) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)(x) x = tf_utils.get_activation(self._activation)(x) elif stem_type == 'v1': x = layers.Conv3D( @@ -248,7 +246,8 @@ class ResNet3D(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)(x) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)(x) x = tf_utils.get_activation(self._activation)(x) x = layers.Conv3D( filters=32, @@ -263,7 +262,8 @@ class ResNet3D(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)(x) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)(x) x = tf_utils.get_activation(self._activation)(x) x = layers.Conv3D( filters=64, @@ -278,7 +278,8 @@ class ResNet3D(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)(x) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)(x) x = tf_utils.get_activation(self._activation)(x) else: raise ValueError(f'Stem type {stem_type} not supported.') diff --git a/official/vision/modeling/backbones/resnet_deeplab.py b/official/vision/modeling/backbones/resnet_deeplab.py index 16a6c71c51b7127ac379bbbd8394482720f52937..cbcf861030e273a9d89f72f170f53d8e3979099d 100644 --- a/official/vision/modeling/backbones/resnet_deeplab.py +++ b/official/vision/modeling/backbones/resnet_deeplab.py @@ -126,10 +126,7 @@ class DilatedResNet(tf.keras.Model): self._activation = activation self._norm_momentum = norm_momentum self._norm_epsilon = norm_epsilon - if use_sync_bn: - self._norm = layers.experimental.SyncBatchNormalization - else: - self._norm = layers.BatchNormalization + self._norm = layers.BatchNormalization self._kernel_initializer = kernel_initializer self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer @@ -159,7 +156,10 @@ class DilatedResNet(tf.keras.Model): bias_regularizer=self._bias_regularizer)( inputs) x = self._norm( - axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( + axis=bn_axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn)( x) x = tf_utils.get_activation(activation)(x) elif stem_type == 'v1': @@ -174,7 +174,10 @@ class DilatedResNet(tf.keras.Model): bias_regularizer=self._bias_regularizer)( inputs) x = self._norm( - axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( + axis=bn_axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn)( x) x = tf_utils.get_activation(activation)(x) x = layers.Conv2D( @@ -188,7 +191,10 @@ class DilatedResNet(tf.keras.Model): bias_regularizer=self._bias_regularizer)( x) x = self._norm( - axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( + axis=bn_axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn)( x) x = tf_utils.get_activation(activation)(x) x = layers.Conv2D( @@ -202,7 +208,10 @@ class DilatedResNet(tf.keras.Model): bias_regularizer=self._bias_regularizer)( x) x = self._norm( - axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( + axis=bn_axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn)( x) x = tf_utils.get_activation(activation)(x) else: @@ -220,7 +229,10 @@ class DilatedResNet(tf.keras.Model): bias_regularizer=self._bias_regularizer)( x) x = self._norm( - axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( + axis=bn_axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn)( x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x) else: diff --git a/official/vision/modeling/backbones/revnet.py b/official/vision/modeling/backbones/revnet.py index 1fa75f42721df56392b05b53686427787490ca10..1f0a864611c37d9f90ad163bc705667b2dabfa15 100644 --- a/official/vision/modeling/backbones/revnet.py +++ b/official/vision/modeling/backbones/revnet.py @@ -93,10 +93,7 @@ class RevNet(tf.keras.Model): self._norm_epsilon = norm_epsilon self._kernel_initializer = kernel_initializer self._kernel_regularizer = kernel_regularizer - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization + self._norm = tf.keras.layers.BatchNormalization axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1 @@ -109,7 +106,10 @@ class RevNet(tf.keras.Model): kernel_initializer=self._kernel_initializer, kernel_regularizer=self._kernel_regularizer)(inputs) x = self._norm( - axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x) + axis=axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn)(x) x = tf_utils.get_activation(activation)(x) x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x) diff --git a/official/vision/modeling/backbones/spinenet.py b/official/vision/modeling/backbones/spinenet.py index 3c5f8ec362b9404023a5cec32cf89653b25aa3f5..cec3a3719e3e14f31719011ff4515c5939f5fcaf 100644 --- a/official/vision/modeling/backbones/spinenet.py +++ b/official/vision/modeling/backbones/spinenet.py @@ -205,11 +205,7 @@ class SpineNet(tf.keras.Model): self._num_init_blocks = 2 self._set_activation_fn(activation) - - if use_sync_bn: - self._norm = layers.experimental.SyncBatchNormalization - else: - self._norm = layers.BatchNormalization + self._norm = layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 @@ -303,7 +299,8 @@ class SpineNet(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) x = tf_utils.get_activation(self._activation_fn)(x) x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x) @@ -434,7 +431,8 @@ class SpineNet(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) x = tf_utils.get_activation(self._activation_fn)(x) endpoints[str(level)] = x @@ -466,7 +464,8 @@ class SpineNet(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) x = tf_utils.get_activation(self._activation_fn)(x) @@ -485,7 +484,8 @@ class SpineNet(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) x = tf_utils.get_activation(self._activation_fn)(x) input_width /= 2 @@ -511,7 +511,8 @@ class SpineNet(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) return x diff --git a/official/vision/modeling/backbones/spinenet_mobile.py b/official/vision/modeling/backbones/spinenet_mobile.py index 563f2b75770f3c1a498ca9182bc125ca8490141f..5009b3e2a08ff1629ed0ec416181c7a8d2ee3289 100644 --- a/official/vision/modeling/backbones/spinenet_mobile.py +++ b/official/vision/modeling/backbones/spinenet_mobile.py @@ -205,11 +205,7 @@ class SpineNetMobile(tf.keras.Model): self._norm_epsilon = norm_epsilon self._use_keras_upsampling_2d = use_keras_upsampling_2d self._num_init_blocks = 2 - - if use_sync_bn: - self._norm = layers.experimental.SyncBatchNormalization - else: - self._norm = layers.BatchNormalization + self._norm = layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 @@ -290,7 +286,8 @@ class SpineNetMobile(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) @@ -428,7 +425,8 @@ class SpineNetMobile(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) endpoints[str(level)] = x @@ -453,7 +451,8 @@ class SpineNetMobile(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) x = tf_utils.get_activation( self._activation, use_keras_layer=True)(x) @@ -476,7 +475,8 @@ class SpineNetMobile(tf.keras.Model): x = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon)( + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn)( x) return x diff --git a/official/vision/modeling/classification_model.py b/official/vision/modeling/classification_model.py index 95d8858c7dfbee8c8849ee7d0225485341104dd5..c1bcdecf2463d07d36356d8409870d9bfdf0f6dc 100644 --- a/official/vision/modeling/classification_model.py +++ b/official/vision/modeling/classification_model.py @@ -62,10 +62,7 @@ class ClassificationModel(tf.keras.Model): skip_logits_layer: `bool`, whether to skip the prediction layer. **kwargs: keyword arguments to be passed. """ - if use_sync_bn: - norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - norm = tf.keras.layers.BatchNormalization + norm = tf.keras.layers.BatchNormalization axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1 inputs = tf.keras.Input(shape=input_specs.shape[1:], name=input_specs.name) @@ -73,7 +70,12 @@ class ClassificationModel(tf.keras.Model): x = endpoints[max(endpoints.keys())] if add_head_batch_norm: - x = norm(axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x) + x = norm( + axis=axis, + momentum=norm_momentum, + epsilon=norm_epsilon, + synchronized=use_sync_bn, + )(x) # Depending on the backbone type, backbone's output can be # [batch_size, height, weight, channel_size] or diff --git a/official/vision/modeling/decoders/fpn.py b/official/vision/modeling/decoders/fpn.py index fa8a38114c15a436711ebea56c79baa1a3c9d3a2..16c434366679f04011777a3f31d1948609b9a5ad 100644 --- a/official/vision/modeling/decoders/fpn.py +++ b/official/vision/modeling/decoders/fpn.py @@ -97,10 +97,7 @@ class FPN(tf.keras.Model): conv2d = tf.keras.layers.SeparableConv2D else: conv2d = tf.keras.layers.Conv2D - if use_sync_bn: - norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - norm = tf.keras.layers.BatchNormalization + norm = tf.keras.layers.BatchNormalization activation_fn = tf_utils.get_activation(activation, use_keras_layer=True) # Build input feature pyramid. @@ -185,6 +182,7 @@ class FPN(tf.keras.Model): axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon, + synchronized=use_sync_bn, name=f'norm_{level}')( feats[str(level)]) diff --git a/official/vision/modeling/decoders/nasfpn.py b/official/vision/modeling/decoders/nasfpn.py index 991421bc21300fcbbead218aea3b813800b775f5..c3359ed10ae79366bdf39131e0f3b23aee84835d 100644 --- a/official/vision/modeling/decoders/nasfpn.py +++ b/official/vision/modeling/decoders/nasfpn.py @@ -137,9 +137,7 @@ class NASFPN(tf.keras.Model): self._conv_op = (tf.keras.layers.SeparableConv2D if self._config_dict['use_separable_conv'] else tf.keras.layers.Conv2D) - self._norm_op = (tf.keras.layers.experimental.SyncBatchNormalization - if self._config_dict['use_sync_bn'] - else tf.keras.layers.BatchNormalization) + self._norm_op = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 else: @@ -148,6 +146,7 @@ class NASFPN(tf.keras.Model): 'axis': self._bn_axis, 'momentum': self._config_dict['norm_momentum'], 'epsilon': self._config_dict['norm_epsilon'], + 'synchronized': self._config_dict['use_sync_bn'], } self._activation = tf_utils.get_activation(activation) diff --git a/official/vision/modeling/heads/instance_heads.py b/official/vision/modeling/heads/instance_heads.py index 4723bb64df58667f2db9df0ed90eb3e4c1fb7e78..9062bcbf69b955f0aabd7dbaf98c4165ee7efae5 100644 --- a/official/vision/modeling/heads/instance_heads.py +++ b/official/vision/modeling/heads/instance_heads.py @@ -120,13 +120,12 @@ class DetectionHead(tf.keras.layers.Layer): 'kernel_regularizer': self._config_dict['kernel_regularizer'], 'bias_regularizer': self._config_dict['bias_regularizer'], }) - bn_op = (tf.keras.layers.experimental.SyncBatchNormalization - if self._config_dict['use_sync_bn'] - else tf.keras.layers.BatchNormalization) + bn_op = tf.keras.layers.BatchNormalization bn_kwargs = { 'axis': self._bn_axis, 'momentum': self._config_dict['norm_momentum'], 'epsilon': self._config_dict['norm_epsilon'], + 'synchronized': self._config_dict['use_sync_bn'], } self._convs = [] @@ -314,13 +313,12 @@ class MaskHead(tf.keras.layers.Layer): 'kernel_regularizer': self._config_dict['kernel_regularizer'], 'bias_regularizer': self._config_dict['bias_regularizer'], }) - bn_op = (tf.keras.layers.experimental.SyncBatchNormalization - if self._config_dict['use_sync_bn'] - else tf.keras.layers.BatchNormalization) + bn_op = tf.keras.layers.BatchNormalization bn_kwargs = { 'axis': self._bn_axis, 'momentum': self._config_dict['norm_momentum'], 'epsilon': self._config_dict['norm_epsilon'], + 'synchronized': self._config_dict['use_sync_bn'], } self._convs = [] diff --git a/official/vision/modeling/heads/segmentation_heads.py b/official/vision/modeling/heads/segmentation_heads.py index 22a0371ebe27042d6c83504ef6926a5d1443da13..b33d0b6f7daee509e0a68603331f74500a378fd3 100644 --- a/official/vision/modeling/heads/segmentation_heads.py +++ b/official/vision/modeling/heads/segmentation_heads.py @@ -108,13 +108,12 @@ class MaskScoring(tf.keras.Model): 'kernel_regularizer': self._config_dict['kernel_regularizer'], 'bias_regularizer': self._config_dict['bias_regularizer'], }) - bn_op = (tf.keras.layers.experimental.SyncBatchNormalization - if self._config_dict['use_sync_bn'] - else tf.keras.layers.BatchNormalization) + bn_op = tf.keras.layers.BatchNormalization bn_kwargs = { 'axis': self._bn_axis, 'momentum': self._config_dict['norm_momentum'], 'epsilon': self._config_dict['norm_epsilon'], + 'synchronized': self._config_dict['use_sync_bn'], } self._convs = [] @@ -324,13 +323,12 @@ class SegmentationHead(tf.keras.layers.Layer): """Creates the variables of the segmentation head.""" use_depthwise_convolution = self._config_dict['use_depthwise_convolution'] conv_op = tf.keras.layers.Conv2D - bn_op = (tf.keras.layers.experimental.SyncBatchNormalization - if self._config_dict['use_sync_bn'] - else tf.keras.layers.BatchNormalization) + bn_op = tf.keras.layers.BatchNormalization bn_kwargs = { 'axis': self._bn_axis, 'momentum': self._config_dict['norm_momentum'], 'epsilon': self._config_dict['norm_epsilon'], + 'synchronized': self._config_dict['use_sync_bn'], } if self._config_dict['feature_fusion'] in {'deeplabv3plus', diff --git a/official/vision/modeling/layers/deeplab.py b/official/vision/modeling/layers/deeplab.py index adb2d7d791e4e0c3fcde7896e005cede7e4ae421..23a77553a2193219d96b1e6c64afe826e0fbd08b 100644 --- a/official/vision/modeling/layers/deeplab.py +++ b/official/vision/modeling/layers/deeplab.py @@ -90,11 +90,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): channels = input_shape[3] self.aspp_layers = [] - - if self.use_sync_bn: - bn_op = tf.keras.layers.experimental.SyncBatchNormalization - else: - bn_op = tf.keras.layers.BatchNormalization + bn_op = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': bn_axis = -1 @@ -112,7 +108,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): bn_op( axis=bn_axis, momentum=self.batchnorm_momentum, - epsilon=self.batchnorm_epsilon), + epsilon=self.batchnorm_epsilon, + synchronized=self.use_sync_bn), tf.keras.layers.Activation(self.activation) ]) self.aspp_layers.append(conv_sequential) @@ -143,7 +140,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): bn_op( axis=bn_axis, momentum=self.batchnorm_momentum, - epsilon=self.batchnorm_epsilon), + epsilon=self.batchnorm_epsilon, + synchronized=self.use_sync_bn), tf.keras.layers.Activation(self.activation) ]) self.aspp_layers.append(conv_sequential) @@ -168,7 +166,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): bn_op( axis=bn_axis, momentum=self.batchnorm_momentum, - epsilon=self.batchnorm_epsilon), + epsilon=self.batchnorm_epsilon, + synchronized=self.use_sync_bn), tf.keras.layers.Activation(self.activation) ])) @@ -185,7 +184,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): bn_op( axis=bn_axis, momentum=self.batchnorm_momentum, - epsilon=self.batchnorm_epsilon), + epsilon=self.batchnorm_epsilon, + synchronized=self.use_sync_bn), tf.keras.layers.Activation(self.activation), tf.keras.layers.Dropout(rate=self.dropout) ]) diff --git a/official/vision/modeling/layers/nn_blocks.py b/official/vision/modeling/layers/nn_blocks.py index 9a81588e31ebd37006dbb4fb8aa792665c993913..712eaaf0faf1a01fe02040dc81c1fd8cb2b7c437 100644 --- a/official/vision/modeling/layers/nn_blocks.py +++ b/official/vision/modeling/layers/nn_blocks.py @@ -123,11 +123,8 @@ class ResidualBlock(tf.keras.layers.Layer): self._norm_epsilon = norm_epsilon self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer + self._norm = tf.keras.layers.BatchNormalization - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 else: @@ -150,7 +147,9 @@ class ResidualBlock(tf.keras.layers.Layer): axis=self._bn_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, - trainable=self._bn_trainable) + trainable=self._bn_trainable, + synchronized=self._use_sync_bn, + ) conv1_padding = 'same' # explicit padding here is added for centernet @@ -171,7 +170,9 @@ class ResidualBlock(tf.keras.layers.Layer): axis=self._bn_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, - trainable=self._bn_trainable) + trainable=self._bn_trainable, + synchronized=self._use_sync_bn, + ) self._conv2 = tf.keras.layers.Conv2D( filters=self._filters, @@ -186,7 +187,9 @@ class ResidualBlock(tf.keras.layers.Layer): axis=self._bn_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, - trainable=self._bn_trainable) + trainable=self._bn_trainable, + synchronized=self._use_sync_bn, + ) if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1: self._squeeze_excitation = nn_layers.SqueezeExcitation( @@ -321,10 +324,8 @@ class BottleneckBlock(tf.keras.layers.Layer): self._norm_epsilon = norm_epsilon self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization + self._norm = tf.keras.layers.BatchNormalization + if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 else: @@ -360,7 +361,9 @@ class BottleneckBlock(tf.keras.layers.Layer): axis=self._bn_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, - trainable=self._bn_trainable) + trainable=self._bn_trainable, + synchronized=self._use_sync_bn, + ) self._conv1 = tf.keras.layers.Conv2D( filters=self._filters, @@ -374,7 +377,9 @@ class BottleneckBlock(tf.keras.layers.Layer): axis=self._bn_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, - trainable=self._bn_trainable) + trainable=self._bn_trainable, + synchronized=self._use_sync_bn, + ) self._activation1 = tf_utils.get_activation( self._activation, use_keras_layer=True) @@ -392,7 +397,9 @@ class BottleneckBlock(tf.keras.layers.Layer): axis=self._bn_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, - trainable=self._bn_trainable) + trainable=self._bn_trainable, + synchronized=self._use_sync_bn, + ) self._activation2 = tf_utils.get_activation( self._activation, use_keras_layer=True) @@ -408,7 +415,9 @@ class BottleneckBlock(tf.keras.layers.Layer): axis=self._bn_axis, momentum=self._norm_momentum, epsilon=self._norm_epsilon, - trainable=self._bn_trainable) + trainable=self._bn_trainable, + synchronized=self._use_sync_bn, + ) self._activation3 = tf_utils.get_activation( self._activation, use_keras_layer=True) @@ -589,11 +598,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): self._bias_regularizer = bias_regularizer self._expand_se_in_filters = expand_se_in_filters self._output_intermediate_endpoints = output_intermediate_endpoints + self._norm = tf.keras.layers.BatchNormalization - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 else: @@ -628,7 +634,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): self._norm0 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._activation_layer = tf_utils.get_activation( self._activation, use_keras_layer=True) @@ -648,7 +656,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): self._norm1 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._depthwise_activation_layer = tf_utils.get_activation( self._depthwise_activation, use_keras_layer=True) @@ -686,7 +696,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): self._norm2 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) if self._stochastic_depth_drop_rate: self._stochastic_depth = nn_layers.StochasticDepth( @@ -812,11 +824,7 @@ class ResidualInner(tf.keras.layers.Layer): self._norm_momentum = norm_momentum self._norm_epsilon = norm_epsilon self._batch_norm_first = batch_norm_first - - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization + self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 @@ -829,7 +837,9 @@ class ResidualInner(tf.keras.layers.Layer): self._batch_norm_0 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._conv2d_1 = tf.keras.layers.Conv2D( filters=self.filters, @@ -843,7 +853,9 @@ class ResidualInner(tf.keras.layers.Layer): self._batch_norm_1 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._conv2d_2 = tf.keras.layers.Conv2D( filters=self.filters, @@ -938,11 +950,7 @@ class BottleneckResidualInner(tf.keras.layers.Layer): self._norm_momentum = norm_momentum self._norm_epsilon = norm_epsilon self._batch_norm_first = batch_norm_first - - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization + self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 @@ -955,7 +963,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer): self._batch_norm_0 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._conv2d_1 = tf.keras.layers.Conv2D( filters=self.filters, kernel_size=1, @@ -967,7 +977,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer): self._batch_norm_1 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._conv2d_2 = tf.keras.layers.Conv2D( filters=self.filters, kernel_size=3, @@ -979,7 +991,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer): self._batch_norm_2 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._conv2d_3 = tf.keras.layers.Conv2D( filters=self.filters * 4, kernel_size=1, @@ -1257,11 +1271,8 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): self._use_sync_bn = use_sync_bn self._norm_momentum = norm_momentum self._norm_epsilon = norm_epsilon + self._norm = tf.keras.layers.BatchNormalization - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 else: @@ -1301,7 +1312,9 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): self._norm0 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._conv1 = tf.keras.layers.Conv2D( filters=self._filters, @@ -1314,7 +1327,9 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): self._norm1 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) super(DepthwiseSeparableConvBlock, self).build(input_shape) @@ -1398,11 +1413,8 @@ class TuckerConvBlock(tf.keras.layers.Layer): self._norm_epsilon = norm_epsilon self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer + self._norm = tf.keras.layers.BatchNormalization - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 else: @@ -1426,7 +1438,9 @@ class TuckerConvBlock(tf.keras.layers.Layer): self._norm0 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._activation_layer0 = tf_utils.get_activation( self._activation, use_keras_layer=True) @@ -1447,7 +1461,9 @@ class TuckerConvBlock(tf.keras.layers.Layer): self._norm1 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) self._activation_layer1 = tf_utils.get_activation( self._activation, use_keras_layer=True) @@ -1464,7 +1480,9 @@ class TuckerConvBlock(tf.keras.layers.Layer): self._norm2 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn, + ) if self._stochastic_depth_drop_rate: self._stochastic_depth = nn_layers.StochasticDepth( diff --git a/official/vision/modeling/layers/nn_blocks_3d.py b/official/vision/modeling/layers/nn_blocks_3d.py index 66087c3671055f7f8c723931e1c183e93541e7b3..6940cc9c82cf087bddb15d0893fb2e21ec8ba165 100644 --- a/official/vision/modeling/layers/nn_blocks_3d.py +++ b/official/vision/modeling/layers/nn_blocks_3d.py @@ -130,11 +130,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): self._norm_epsilon = norm_epsilon self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer + self._norm = tf.keras.layers.BatchNormalization - if use_sync_bn: - self._norm = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._norm = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 else: @@ -161,7 +158,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): self._norm0 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn) self._temporal_conv = tf.keras.layers.Conv3D( filters=self._filters, @@ -175,7 +173,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): self._norm1 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn) self._spatial_conv = tf.keras.layers.Conv3D( filters=self._filters, @@ -189,7 +188,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): self._norm2 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn) self._expand_conv = tf.keras.layers.Conv3D( filters=4 * self._filters, @@ -203,7 +203,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): self._norm3 = self._norm( axis=self._bn_axis, momentum=self._norm_momentum, - epsilon=self._norm_epsilon) + epsilon=self._norm_epsilon, + synchronized=self._use_sync_bn) if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1: self._squeeze_excitation = nn_layers.SqueezeExcitation( diff --git a/official/vision/modeling/layers/nn_layers.py b/official/vision/modeling/layers/nn_layers.py index a82db65c121fefaba489ea17cff739596b9e3704..bbb96998a78c300344bd02012c481b9f9bcf8eee 100644 --- a/official/vision/modeling/layers/nn_layers.py +++ b/official/vision/modeling/layers/nn_layers.py @@ -1135,10 +1135,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): self._pool_kernel_size = pool_kernel_size self._use_depthwise_convolution = use_depthwise_convolution self._activation_fn = tf_utils.get_activation(activation) - if self._use_sync_bn: - self._bn_op = tf.keras.layers.experimental.SyncBatchNormalization - else: - self._bn_op = tf.keras.layers.BatchNormalization + self._bn_op = tf.keras.layers.BatchNormalization if tf.keras.backend.image_data_format() == 'channels_last': self._bn_axis = -1 @@ -1161,7 +1158,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): norm1 = self._bn_op( axis=self._bn_axis, momentum=self._batchnorm_momentum, - epsilon=self._batchnorm_epsilon) + epsilon=self._batchnorm_epsilon, + synchronized=self._use_sync_bn) self.aspp_layers.append([conv1, norm1]) @@ -1195,7 +1193,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): norm_dilation = self._bn_op( axis=self._bn_axis, momentum=self._batchnorm_momentum, - epsilon=self._batchnorm_epsilon) + epsilon=self._batchnorm_epsilon, + synchronized=self._use_sync_bn) self.aspp_layers.append(conv_dilation + [norm_dilation]) @@ -1216,7 +1215,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): norm2 = self._bn_op( axis=self._bn_axis, momentum=self._batchnorm_momentum, - epsilon=self._batchnorm_epsilon) + epsilon=self._batchnorm_epsilon, + synchronized=self._use_sync_bn) self.aspp_layers.append(pooling + [conv2, norm2]) @@ -1234,7 +1234,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): self._bn_op( axis=self._bn_axis, momentum=self._batchnorm_momentum, - epsilon=self._batchnorm_epsilon) + epsilon=self._batchnorm_epsilon, + synchronized=self._use_sync_bn) ] self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout) self._concat_layer = tf.keras.layers.Concatenate(axis=-1)