提交 350243e2 编写于 作者: A A. Unique TensorFlower

Remove Deprecated `tf.keras.layers.experimental.SyncBatchNormalization` from...

Remove deprecated `tf.keras.layers.experimental.SyncBatchNormalization` from existing code and use `tf.keras.layers.BatchNormalization` with the `synchronized` argument instead.

PiperOrigin-RevId: 526102034
上级 3b4caa8f
...@@ -151,10 +151,7 @@ class EfficientNet(tf.keras.Model): ...@@ -151,10 +151,7 @@ class EfficientNet(tf.keras.Model):
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
if use_sync_bn: self._norm = layers.BatchNormalization
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1 bn_axis = -1
...@@ -178,7 +175,10 @@ class EfficientNet(tf.keras.Model): ...@@ -178,7 +175,10 @@ class EfficientNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
inputs) inputs)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x) x)
x = tf_utils.get_activation(activation)(x) x = tf_utils.get_activation(activation)(x)
...@@ -210,7 +210,10 @@ class EfficientNet(tf.keras.Model): ...@@ -210,7 +210,10 @@ class EfficientNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x) x)
endpoints[str(endpoint_level)] = tf_utils.get_activation(activation)(x) endpoints[str(endpoint_level)] = tf_utils.get_activation(activation)(x)
......
...@@ -91,15 +91,12 @@ class Conv2DBNBlock(tf.keras.layers.Layer): ...@@ -91,15 +91,12 @@ class Conv2DBNBlock(tf.keras.layers.Layer):
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._norm = tf.keras.layers.BatchNormalization
if use_explicit_padding and kernel_size > 1: if use_explicit_padding and kernel_size > 1:
self._padding = 'valid' self._padding = 'valid'
else: else:
self._padding = 'same' self._padding = 'same'
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
else: else:
...@@ -141,7 +138,8 @@ class Conv2DBNBlock(tf.keras.layers.Layer): ...@@ -141,7 +138,8 @@ class Conv2DBNBlock(tf.keras.layers.Layer):
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._activation_layer = tf_utils.get_activation( self._activation_layer = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
......
...@@ -179,10 +179,7 @@ class ResNet(tf.keras.Model): ...@@ -179,10 +179,7 @@ class ResNet(tf.keras.Model):
self._activation = activation self._activation = activation
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
if use_sync_bn: self._norm = layers.BatchNormalization
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
...@@ -238,6 +235,7 @@ class ResNet(tf.keras.Model): ...@@ -238,6 +235,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable, trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x) )(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
elif self._stem_type == 'v1': elif self._stem_type == 'v1':
...@@ -256,6 +254,7 @@ class ResNet(tf.keras.Model): ...@@ -256,6 +254,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable, trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x) )(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
x = layers.Conv2D( x = layers.Conv2D(
...@@ -273,6 +272,7 @@ class ResNet(tf.keras.Model): ...@@ -273,6 +272,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable, trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x) )(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
x = layers.Conv2D( x = layers.Conv2D(
...@@ -290,6 +290,7 @@ class ResNet(tf.keras.Model): ...@@ -290,6 +290,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable, trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x) )(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
else: else:
...@@ -311,6 +312,7 @@ class ResNet(tf.keras.Model): ...@@ -311,6 +312,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable, trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x) )(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
else: else:
......
...@@ -145,10 +145,7 @@ class ResNet3D(tf.keras.Model): ...@@ -145,10 +145,7 @@ class ResNet3D(tf.keras.Model):
self._activation = activation self._activation = activation
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
if use_sync_bn: self._norm = layers.BatchNormalization
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
...@@ -232,7 +229,8 @@ class ResNet3D(tf.keras.Model): ...@@ -232,7 +229,8 @@ class ResNet3D(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(x) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(x)
x = tf_utils.get_activation(self._activation)(x) x = tf_utils.get_activation(self._activation)(x)
elif stem_type == 'v1': elif stem_type == 'v1':
x = layers.Conv3D( x = layers.Conv3D(
...@@ -248,7 +246,8 @@ class ResNet3D(tf.keras.Model): ...@@ -248,7 +246,8 @@ class ResNet3D(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(x) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(x)
x = tf_utils.get_activation(self._activation)(x) x = tf_utils.get_activation(self._activation)(x)
x = layers.Conv3D( x = layers.Conv3D(
filters=32, filters=32,
...@@ -263,7 +262,8 @@ class ResNet3D(tf.keras.Model): ...@@ -263,7 +262,8 @@ class ResNet3D(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(x) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(x)
x = tf_utils.get_activation(self._activation)(x) x = tf_utils.get_activation(self._activation)(x)
x = layers.Conv3D( x = layers.Conv3D(
filters=64, filters=64,
...@@ -278,7 +278,8 @@ class ResNet3D(tf.keras.Model): ...@@ -278,7 +278,8 @@ class ResNet3D(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(x) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(x)
x = tf_utils.get_activation(self._activation)(x) x = tf_utils.get_activation(self._activation)(x)
else: else:
raise ValueError(f'Stem type {stem_type} not supported.') raise ValueError(f'Stem type {stem_type} not supported.')
......
...@@ -126,10 +126,7 @@ class DilatedResNet(tf.keras.Model): ...@@ -126,10 +126,7 @@ class DilatedResNet(tf.keras.Model):
self._activation = activation self._activation = activation
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
if use_sync_bn: self._norm = layers.BatchNormalization
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
...@@ -159,7 +156,10 @@ class DilatedResNet(tf.keras.Model): ...@@ -159,7 +156,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
inputs) inputs)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x) x)
x = tf_utils.get_activation(activation)(x) x = tf_utils.get_activation(activation)(x)
elif stem_type == 'v1': elif stem_type == 'v1':
...@@ -174,7 +174,10 @@ class DilatedResNet(tf.keras.Model): ...@@ -174,7 +174,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
inputs) inputs)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x) x)
x = tf_utils.get_activation(activation)(x) x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D( x = layers.Conv2D(
...@@ -188,7 +191,10 @@ class DilatedResNet(tf.keras.Model): ...@@ -188,7 +191,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x) x)
x = tf_utils.get_activation(activation)(x) x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D( x = layers.Conv2D(
...@@ -202,7 +208,10 @@ class DilatedResNet(tf.keras.Model): ...@@ -202,7 +208,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x) x)
x = tf_utils.get_activation(activation)(x) x = tf_utils.get_activation(activation)(x)
else: else:
...@@ -220,7 +229,10 @@ class DilatedResNet(tf.keras.Model): ...@@ -220,7 +229,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x) x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x) x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else: else:
......
...@@ -93,10 +93,7 @@ class RevNet(tf.keras.Model): ...@@ -93,10 +93,7 @@ class RevNet(tf.keras.Model):
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
if use_sync_bn: self._norm = tf.keras.layers.BatchNormalization
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1 axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1
...@@ -109,7 +106,10 @@ class RevNet(tf.keras.Model): ...@@ -109,7 +106,10 @@ class RevNet(tf.keras.Model):
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer)(inputs) kernel_regularizer=self._kernel_regularizer)(inputs)
x = self._norm( x = self._norm(
axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x) axis=axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(x)
x = tf_utils.get_activation(activation)(x) x = tf_utils.get_activation(activation)(x)
x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x) x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
......
...@@ -205,11 +205,7 @@ class SpineNet(tf.keras.Model): ...@@ -205,11 +205,7 @@ class SpineNet(tf.keras.Model):
self._num_init_blocks = 2 self._num_init_blocks = 2
self._set_activation_fn(activation) self._set_activation_fn(activation)
self._norm = layers.BatchNormalization
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
...@@ -303,7 +299,8 @@ class SpineNet(tf.keras.Model): ...@@ -303,7 +299,8 @@ class SpineNet(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
x = tf_utils.get_activation(self._activation_fn)(x) x = tf_utils.get_activation(self._activation_fn)(x)
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x) x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
...@@ -434,7 +431,8 @@ class SpineNet(tf.keras.Model): ...@@ -434,7 +431,8 @@ class SpineNet(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
x = tf_utils.get_activation(self._activation_fn)(x) x = tf_utils.get_activation(self._activation_fn)(x)
endpoints[str(level)] = x endpoints[str(level)] = x
...@@ -466,7 +464,8 @@ class SpineNet(tf.keras.Model): ...@@ -466,7 +464,8 @@ class SpineNet(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
x = tf_utils.get_activation(self._activation_fn)(x) x = tf_utils.get_activation(self._activation_fn)(x)
...@@ -485,7 +484,8 @@ class SpineNet(tf.keras.Model): ...@@ -485,7 +484,8 @@ class SpineNet(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
x = tf_utils.get_activation(self._activation_fn)(x) x = tf_utils.get_activation(self._activation_fn)(x)
input_width /= 2 input_width /= 2
...@@ -511,7 +511,8 @@ class SpineNet(tf.keras.Model): ...@@ -511,7 +511,8 @@ class SpineNet(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
return x return x
......
...@@ -205,11 +205,7 @@ class SpineNetMobile(tf.keras.Model): ...@@ -205,11 +205,7 @@ class SpineNetMobile(tf.keras.Model):
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._use_keras_upsampling_2d = use_keras_upsampling_2d self._use_keras_upsampling_2d = use_keras_upsampling_2d
self._num_init_blocks = 2 self._num_init_blocks = 2
self._norm = layers.BatchNormalization
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
...@@ -290,7 +286,8 @@ class SpineNetMobile(tf.keras.Model): ...@@ -290,7 +286,8 @@ class SpineNetMobile(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
...@@ -428,7 +425,8 @@ class SpineNetMobile(tf.keras.Model): ...@@ -428,7 +425,8 @@ class SpineNetMobile(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
endpoints[str(level)] = x endpoints[str(level)] = x
...@@ -453,7 +451,8 @@ class SpineNetMobile(tf.keras.Model): ...@@ -453,7 +451,8 @@ class SpineNetMobile(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
x = tf_utils.get_activation( x = tf_utils.get_activation(
self._activation, use_keras_layer=True)(x) self._activation, use_keras_layer=True)(x)
...@@ -476,7 +475,8 @@ class SpineNetMobile(tf.keras.Model): ...@@ -476,7 +475,8 @@ class SpineNetMobile(tf.keras.Model):
x = self._norm( x = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x) x)
return x return x
......
...@@ -62,10 +62,7 @@ class ClassificationModel(tf.keras.Model): ...@@ -62,10 +62,7 @@ class ClassificationModel(tf.keras.Model):
skip_logits_layer: `bool`, whether to skip the prediction layer. skip_logits_layer: `bool`, whether to skip the prediction layer.
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
if use_sync_bn: norm = tf.keras.layers.BatchNormalization
norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
norm = tf.keras.layers.BatchNormalization
axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1 axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1
inputs = tf.keras.Input(shape=input_specs.shape[1:], name=input_specs.name) inputs = tf.keras.Input(shape=input_specs.shape[1:], name=input_specs.name)
...@@ -73,7 +70,12 @@ class ClassificationModel(tf.keras.Model): ...@@ -73,7 +70,12 @@ class ClassificationModel(tf.keras.Model):
x = endpoints[max(endpoints.keys())] x = endpoints[max(endpoints.keys())]
if add_head_batch_norm: if add_head_batch_norm:
x = norm(axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x) x = norm(
axis=axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn,
)(x)
# Depending on the backbone type, backbone's output can be # Depending on the backbone type, backbone's output can be
# [batch_size, height, weight, channel_size] or # [batch_size, height, weight, channel_size] or
......
...@@ -97,10 +97,7 @@ class FPN(tf.keras.Model): ...@@ -97,10 +97,7 @@ class FPN(tf.keras.Model):
conv2d = tf.keras.layers.SeparableConv2D conv2d = tf.keras.layers.SeparableConv2D
else: else:
conv2d = tf.keras.layers.Conv2D conv2d = tf.keras.layers.Conv2D
if use_sync_bn: norm = tf.keras.layers.BatchNormalization
norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
norm = tf.keras.layers.BatchNormalization
activation_fn = tf_utils.get_activation(activation, use_keras_layer=True) activation_fn = tf_utils.get_activation(activation, use_keras_layer=True)
# Build input feature pyramid. # Build input feature pyramid.
...@@ -185,6 +182,7 @@ class FPN(tf.keras.Model): ...@@ -185,6 +182,7 @@ class FPN(tf.keras.Model):
axis=bn_axis, axis=bn_axis,
momentum=norm_momentum, momentum=norm_momentum,
epsilon=norm_epsilon, epsilon=norm_epsilon,
synchronized=use_sync_bn,
name=f'norm_{level}')( name=f'norm_{level}')(
feats[str(level)]) feats[str(level)])
......
...@@ -137,9 +137,7 @@ class NASFPN(tf.keras.Model): ...@@ -137,9 +137,7 @@ class NASFPN(tf.keras.Model):
self._conv_op = (tf.keras.layers.SeparableConv2D self._conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv'] if self._config_dict['use_separable_conv']
else tf.keras.layers.Conv2D) else tf.keras.layers.Conv2D)
self._norm_op = (tf.keras.layers.experimental.SyncBatchNormalization self._norm_op = tf.keras.layers.BatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
else: else:
...@@ -148,6 +146,7 @@ class NASFPN(tf.keras.Model): ...@@ -148,6 +146,7 @@ class NASFPN(tf.keras.Model):
'axis': self._bn_axis, 'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'], 'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'], 'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
} }
self._activation = tf_utils.get_activation(activation) self._activation = tf_utils.get_activation(activation)
......
...@@ -120,13 +120,12 @@ class DetectionHead(tf.keras.layers.Layer): ...@@ -120,13 +120,12 @@ class DetectionHead(tf.keras.layers.Layer):
'kernel_regularizer': self._config_dict['kernel_regularizer'], 'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'], 'bias_regularizer': self._config_dict['bias_regularizer'],
}) })
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization bn_op = tf.keras.layers.BatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_kwargs = { bn_kwargs = {
'axis': self._bn_axis, 'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'], 'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'], 'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
} }
self._convs = [] self._convs = []
...@@ -314,13 +313,12 @@ class MaskHead(tf.keras.layers.Layer): ...@@ -314,13 +313,12 @@ class MaskHead(tf.keras.layers.Layer):
'kernel_regularizer': self._config_dict['kernel_regularizer'], 'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'], 'bias_regularizer': self._config_dict['bias_regularizer'],
}) })
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization bn_op = tf.keras.layers.BatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_kwargs = { bn_kwargs = {
'axis': self._bn_axis, 'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'], 'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'], 'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
} }
self._convs = [] self._convs = []
......
...@@ -108,13 +108,12 @@ class MaskScoring(tf.keras.Model): ...@@ -108,13 +108,12 @@ class MaskScoring(tf.keras.Model):
'kernel_regularizer': self._config_dict['kernel_regularizer'], 'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'], 'bias_regularizer': self._config_dict['bias_regularizer'],
}) })
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization bn_op = tf.keras.layers.BatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_kwargs = { bn_kwargs = {
'axis': self._bn_axis, 'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'], 'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'], 'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
} }
self._convs = [] self._convs = []
...@@ -324,13 +323,12 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -324,13 +323,12 @@ class SegmentationHead(tf.keras.layers.Layer):
"""Creates the variables of the segmentation head.""" """Creates the variables of the segmentation head."""
use_depthwise_convolution = self._config_dict['use_depthwise_convolution'] use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
conv_op = tf.keras.layers.Conv2D conv_op = tf.keras.layers.Conv2D
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization bn_op = tf.keras.layers.BatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_kwargs = { bn_kwargs = {
'axis': self._bn_axis, 'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'], 'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'], 'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
} }
if self._config_dict['feature_fusion'] in {'deeplabv3plus', if self._config_dict['feature_fusion'] in {'deeplabv3plus',
......
...@@ -90,11 +90,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -90,11 +90,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
channels = input_shape[3] channels = input_shape[3]
self.aspp_layers = [] self.aspp_layers = []
bn_op = tf.keras.layers.BatchNormalization
if self.use_sync_bn:
bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
bn_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1 bn_axis = -1
...@@ -112,7 +108,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -112,7 +108,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
bn_op( bn_op(
axis=bn_axis, axis=bn_axis,
momentum=self.batchnorm_momentum, momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon), epsilon=self.batchnorm_epsilon,
synchronized=self.use_sync_bn),
tf.keras.layers.Activation(self.activation) tf.keras.layers.Activation(self.activation)
]) ])
self.aspp_layers.append(conv_sequential) self.aspp_layers.append(conv_sequential)
...@@ -143,7 +140,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -143,7 +140,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
bn_op( bn_op(
axis=bn_axis, axis=bn_axis,
momentum=self.batchnorm_momentum, momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon), epsilon=self.batchnorm_epsilon,
synchronized=self.use_sync_bn),
tf.keras.layers.Activation(self.activation) tf.keras.layers.Activation(self.activation)
]) ])
self.aspp_layers.append(conv_sequential) self.aspp_layers.append(conv_sequential)
...@@ -168,7 +166,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -168,7 +166,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
bn_op( bn_op(
axis=bn_axis, axis=bn_axis,
momentum=self.batchnorm_momentum, momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon), epsilon=self.batchnorm_epsilon,
synchronized=self.use_sync_bn),
tf.keras.layers.Activation(self.activation) tf.keras.layers.Activation(self.activation)
])) ]))
...@@ -185,7 +184,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -185,7 +184,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
bn_op( bn_op(
axis=bn_axis, axis=bn_axis,
momentum=self.batchnorm_momentum, momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon), epsilon=self.batchnorm_epsilon,
synchronized=self.use_sync_bn),
tf.keras.layers.Activation(self.activation), tf.keras.layers.Activation(self.activation),
tf.keras.layers.Dropout(rate=self.dropout) tf.keras.layers.Dropout(rate=self.dropout)
]) ])
......
...@@ -123,11 +123,8 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -123,11 +123,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
else: else:
...@@ -150,7 +147,9 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -150,7 +147,9 @@ class ResidualBlock(tf.keras.layers.Layer):
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable) trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
conv1_padding = 'same' conv1_padding = 'same'
# explicit padding here is added for centernet # explicit padding here is added for centernet
...@@ -171,7 +170,9 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -171,7 +170,9 @@ class ResidualBlock(tf.keras.layers.Layer):
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable) trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._conv2 = tf.keras.layers.Conv2D( self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters, filters=self._filters,
...@@ -186,7 +187,9 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -186,7 +187,9 @@ class ResidualBlock(tf.keras.layers.Layer):
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable) trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1: if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation( self._squeeze_excitation = nn_layers.SqueezeExcitation(
...@@ -321,10 +324,8 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -321,10 +324,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
if use_sync_bn: self._norm = tf.keras.layers.BatchNormalization
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
else: else:
...@@ -360,7 +361,9 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -360,7 +361,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable) trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._conv1 = tf.keras.layers.Conv2D( self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters, filters=self._filters,
...@@ -374,7 +377,9 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -374,7 +377,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable) trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._activation1 = tf_utils.get_activation( self._activation1 = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -392,7 +397,9 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -392,7 +397,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable) trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._activation2 = tf_utils.get_activation( self._activation2 = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -408,7 +415,9 @@ class BottleneckBlock(tf.keras.layers.Layer): ...@@ -408,7 +415,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon, epsilon=self._norm_epsilon,
trainable=self._bn_trainable) trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._activation3 = tf_utils.get_activation( self._activation3 = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -589,11 +598,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -589,11 +598,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
self._expand_se_in_filters = expand_se_in_filters self._expand_se_in_filters = expand_se_in_filters
self._output_intermediate_endpoints = output_intermediate_endpoints self._output_intermediate_endpoints = output_intermediate_endpoints
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
else: else:
...@@ -628,7 +634,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -628,7 +634,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._activation_layer = tf_utils.get_activation( self._activation_layer = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -648,7 +656,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -648,7 +656,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._norm1 = self._norm( self._norm1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._depthwise_activation_layer = tf_utils.get_activation( self._depthwise_activation_layer = tf_utils.get_activation(
self._depthwise_activation, use_keras_layer=True) self._depthwise_activation, use_keras_layer=True)
...@@ -686,7 +696,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer): ...@@ -686,7 +696,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._norm2 = self._norm( self._norm2 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
if self._stochastic_depth_drop_rate: if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth( self._stochastic_depth = nn_layers.StochasticDepth(
...@@ -812,11 +824,7 @@ class ResidualInner(tf.keras.layers.Layer): ...@@ -812,11 +824,7 @@ class ResidualInner(tf.keras.layers.Layer):
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._batch_norm_first = batch_norm_first self._batch_norm_first = batch_norm_first
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
...@@ -829,7 +837,9 @@ class ResidualInner(tf.keras.layers.Layer): ...@@ -829,7 +837,9 @@ class ResidualInner(tf.keras.layers.Layer):
self._batch_norm_0 = self._norm( self._batch_norm_0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_1 = tf.keras.layers.Conv2D( self._conv2d_1 = tf.keras.layers.Conv2D(
filters=self.filters, filters=self.filters,
...@@ -843,7 +853,9 @@ class ResidualInner(tf.keras.layers.Layer): ...@@ -843,7 +853,9 @@ class ResidualInner(tf.keras.layers.Layer):
self._batch_norm_1 = self._norm( self._batch_norm_1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_2 = tf.keras.layers.Conv2D( self._conv2d_2 = tf.keras.layers.Conv2D(
filters=self.filters, filters=self.filters,
...@@ -938,11 +950,7 @@ class BottleneckResidualInner(tf.keras.layers.Layer): ...@@ -938,11 +950,7 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._batch_norm_first = batch_norm_first self._batch_norm_first = batch_norm_first
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
...@@ -955,7 +963,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer): ...@@ -955,7 +963,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
self._batch_norm_0 = self._norm( self._batch_norm_0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_1 = tf.keras.layers.Conv2D( self._conv2d_1 = tf.keras.layers.Conv2D(
filters=self.filters, filters=self.filters,
kernel_size=1, kernel_size=1,
...@@ -967,7 +977,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer): ...@@ -967,7 +977,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
self._batch_norm_1 = self._norm( self._batch_norm_1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_2 = tf.keras.layers.Conv2D( self._conv2d_2 = tf.keras.layers.Conv2D(
filters=self.filters, filters=self.filters,
kernel_size=3, kernel_size=3,
...@@ -979,7 +991,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer): ...@@ -979,7 +991,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
self._batch_norm_2 = self._norm( self._batch_norm_2 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_3 = tf.keras.layers.Conv2D( self._conv2d_3 = tf.keras.layers.Conv2D(
filters=self.filters * 4, filters=self.filters * 4,
kernel_size=1, kernel_size=1,
...@@ -1257,11 +1271,8 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): ...@@ -1257,11 +1271,8 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
else: else:
...@@ -1301,7 +1312,9 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): ...@@ -1301,7 +1312,9 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv1 = tf.keras.layers.Conv2D( self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters, filters=self._filters,
...@@ -1314,7 +1327,9 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer): ...@@ -1314,7 +1327,9 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
self._norm1 = self._norm( self._norm1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
super(DepthwiseSeparableConvBlock, self).build(input_shape) super(DepthwiseSeparableConvBlock, self).build(input_shape)
...@@ -1398,11 +1413,8 @@ class TuckerConvBlock(tf.keras.layers.Layer): ...@@ -1398,11 +1413,8 @@ class TuckerConvBlock(tf.keras.layers.Layer):
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
else: else:
...@@ -1426,7 +1438,9 @@ class TuckerConvBlock(tf.keras.layers.Layer): ...@@ -1426,7 +1438,9 @@ class TuckerConvBlock(tf.keras.layers.Layer):
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._activation_layer0 = tf_utils.get_activation( self._activation_layer0 = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -1447,7 +1461,9 @@ class TuckerConvBlock(tf.keras.layers.Layer): ...@@ -1447,7 +1461,9 @@ class TuckerConvBlock(tf.keras.layers.Layer):
self._norm1 = self._norm( self._norm1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._activation_layer1 = tf_utils.get_activation( self._activation_layer1 = tf_utils.get_activation(
self._activation, use_keras_layer=True) self._activation, use_keras_layer=True)
...@@ -1464,7 +1480,9 @@ class TuckerConvBlock(tf.keras.layers.Layer): ...@@ -1464,7 +1480,9 @@ class TuckerConvBlock(tf.keras.layers.Layer):
self._norm2 = self._norm( self._norm2 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
if self._stochastic_depth_drop_rate: if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth( self._stochastic_depth = nn_layers.StochasticDepth(
......
...@@ -130,11 +130,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -130,11 +130,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
else: else:
...@@ -161,7 +158,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -161,7 +158,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm0 = self._norm( self._norm0 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._temporal_conv = tf.keras.layers.Conv3D( self._temporal_conv = tf.keras.layers.Conv3D(
filters=self._filters, filters=self._filters,
...@@ -175,7 +173,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -175,7 +173,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm1 = self._norm( self._norm1 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._spatial_conv = tf.keras.layers.Conv3D( self._spatial_conv = tf.keras.layers.Conv3D(
filters=self._filters, filters=self._filters,
...@@ -189,7 +188,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -189,7 +188,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm2 = self._norm( self._norm2 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._expand_conv = tf.keras.layers.Conv3D( self._expand_conv = tf.keras.layers.Conv3D(
filters=4 * self._filters, filters=4 * self._filters,
...@@ -203,7 +203,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -203,7 +203,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm3 = self._norm( self._norm3 = self._norm(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon) epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1: if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation( self._squeeze_excitation = nn_layers.SqueezeExcitation(
......
...@@ -1135,10 +1135,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1135,10 +1135,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self._pool_kernel_size = pool_kernel_size self._pool_kernel_size = pool_kernel_size
self._use_depthwise_convolution = use_depthwise_convolution self._use_depthwise_convolution = use_depthwise_convolution
self._activation_fn = tf_utils.get_activation(activation) self._activation_fn = tf_utils.get_activation(activation)
if self._use_sync_bn: self._bn_op = tf.keras.layers.BatchNormalization
self._bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._bn_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1 self._bn_axis = -1
...@@ -1161,7 +1158,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1161,7 +1158,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
norm1 = self._bn_op( norm1 = self._bn_op(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._batchnorm_momentum, momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon) epsilon=self._batchnorm_epsilon,
synchronized=self._use_sync_bn)
self.aspp_layers.append([conv1, norm1]) self.aspp_layers.append([conv1, norm1])
...@@ -1195,7 +1193,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1195,7 +1193,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
norm_dilation = self._bn_op( norm_dilation = self._bn_op(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._batchnorm_momentum, momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon) epsilon=self._batchnorm_epsilon,
synchronized=self._use_sync_bn)
self.aspp_layers.append(conv_dilation + [norm_dilation]) self.aspp_layers.append(conv_dilation + [norm_dilation])
...@@ -1216,7 +1215,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1216,7 +1215,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
norm2 = self._bn_op( norm2 = self._bn_op(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._batchnorm_momentum, momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon) epsilon=self._batchnorm_epsilon,
synchronized=self._use_sync_bn)
self.aspp_layers.append(pooling + [conv2, norm2]) self.aspp_layers.append(pooling + [conv2, norm2])
...@@ -1234,7 +1234,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1234,7 +1234,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self._bn_op( self._bn_op(
axis=self._bn_axis, axis=self._bn_axis,
momentum=self._batchnorm_momentum, momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon) epsilon=self._batchnorm_epsilon,
synchronized=self._use_sync_bn)
] ]
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout) self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
self._concat_layer = tf.keras.layers.Concatenate(axis=-1) self._concat_layer = tf.keras.layers.Concatenate(axis=-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册