提交 350243e2 编写于 作者: A A. Unique TensorFlower

Remove deprecated `tf.keras.layers.experimental.SyncBatchNormalization` from...

Remove deprecated `tf.keras.layers.experimental.SyncBatchNormalization` from existing code and update `tf.keras.layers.BatchNormalization` usages with the `synchronized` argument.

PiperOrigin-RevId: 526102034
上级 3b4caa8f
......@@ -151,10 +151,7 @@ class EfficientNet(tf.keras.Model):
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
......@@ -178,7 +175,10 @@ class EfficientNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x)
x = tf_utils.get_activation(activation)(x)
......@@ -210,7 +210,10 @@ class EfficientNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x)
endpoints[str(endpoint_level)] = tf_utils.get_activation(activation)(x)
......
......@@ -91,15 +91,12 @@ class Conv2DBNBlock(tf.keras.layers.Layer):
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._norm = tf.keras.layers.BatchNormalization
if use_explicit_padding and kernel_size > 1:
self._padding = 'valid'
else:
self._padding = 'same'
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
......@@ -141,7 +138,8 @@ class Conv2DBNBlock(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._activation_layer = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......
......@@ -179,10 +179,7 @@ class ResNet(tf.keras.Model):
self._activation = activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._norm = layers.BatchNormalization
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
......@@ -238,6 +235,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
elif self._stem_type == 'v1':
......@@ -256,6 +254,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
x = layers.Conv2D(
......@@ -273,6 +272,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
x = layers.Conv2D(
......@@ -290,6 +290,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
else:
......@@ -311,6 +312,7 @@ class ResNet(tf.keras.Model):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
else:
......
......@@ -145,10 +145,7 @@ class ResNet3D(tf.keras.Model):
self._activation = activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._norm = layers.BatchNormalization
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
......@@ -232,7 +229,8 @@ class ResNet3D(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(x)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(x)
x = tf_utils.get_activation(self._activation)(x)
elif stem_type == 'v1':
x = layers.Conv3D(
......@@ -248,7 +246,8 @@ class ResNet3D(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(x)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(x)
x = tf_utils.get_activation(self._activation)(x)
x = layers.Conv3D(
filters=32,
......@@ -263,7 +262,8 @@ class ResNet3D(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(x)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(x)
x = tf_utils.get_activation(self._activation)(x)
x = layers.Conv3D(
filters=64,
......@@ -278,7 +278,8 @@ class ResNet3D(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(x)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(x)
x = tf_utils.get_activation(self._activation)(x)
else:
raise ValueError(f'Stem type {stem_type} not supported.')
......
......@@ -126,10 +126,7 @@ class DilatedResNet(tf.keras.Model):
self._activation = activation
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._norm = layers.BatchNormalization
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
......@@ -159,7 +156,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x)
x = tf_utils.get_activation(activation)(x)
elif stem_type == 'v1':
......@@ -174,7 +174,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D(
......@@ -188,7 +191,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D(
......@@ -202,7 +208,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x)
x = tf_utils.get_activation(activation)(x)
else:
......@@ -220,7 +229,10 @@ class DilatedResNet(tf.keras.Model):
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(
x)
x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
else:
......
......@@ -93,10 +93,7 @@ class RevNet(tf.keras.Model):
self._norm_epsilon = norm_epsilon
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._norm = tf.keras.layers.BatchNormalization
axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1
......@@ -109,7 +106,10 @@ class RevNet(tf.keras.Model):
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer)(inputs)
x = self._norm(
axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
axis=axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn)(x)
x = tf_utils.get_activation(activation)(x)
x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
......
......@@ -205,11 +205,7 @@ class SpineNet(tf.keras.Model):
self._num_init_blocks = 2
self._set_activation_fn(activation)
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
......@@ -303,7 +299,8 @@ class SpineNet(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
x = tf_utils.get_activation(self._activation_fn)(x)
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
......@@ -434,7 +431,8 @@ class SpineNet(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
x = tf_utils.get_activation(self._activation_fn)(x)
endpoints[str(level)] = x
......@@ -466,7 +464,8 @@ class SpineNet(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
x = tf_utils.get_activation(self._activation_fn)(x)
......@@ -485,7 +484,8 @@ class SpineNet(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
x = tf_utils.get_activation(self._activation_fn)(x)
input_width /= 2
......@@ -511,7 +511,8 @@ class SpineNet(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
return x
......
......@@ -205,11 +205,7 @@ class SpineNetMobile(tf.keras.Model):
self._norm_epsilon = norm_epsilon
self._use_keras_upsampling_2d = use_keras_upsampling_2d
self._num_init_blocks = 2
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
......@@ -290,7 +286,8 @@ class SpineNetMobile(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
......@@ -428,7 +425,8 @@ class SpineNetMobile(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
endpoints[str(level)] = x
......@@ -453,7 +451,8 @@ class SpineNetMobile(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
x = tf_utils.get_activation(
self._activation, use_keras_layer=True)(x)
......@@ -476,7 +475,8 @@ class SpineNetMobile(tf.keras.Model):
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)(
x)
return x
......
......@@ -62,10 +62,7 @@ class ClassificationModel(tf.keras.Model):
skip_logits_layer: `bool`, whether to skip the prediction layer.
**kwargs: keyword arguments to be passed.
"""
if use_sync_bn:
norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
norm = tf.keras.layers.BatchNormalization
norm = tf.keras.layers.BatchNormalization
axis = -1 if tf.keras.backend.image_data_format() == 'channels_last' else 1
inputs = tf.keras.Input(shape=input_specs.shape[1:], name=input_specs.name)
......@@ -73,7 +70,12 @@ class ClassificationModel(tf.keras.Model):
x = endpoints[max(endpoints.keys())]
if add_head_batch_norm:
x = norm(axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
x = norm(
axis=axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn,
)(x)
# Depending on the backbone type, backbone's output can be
# [batch_size, height, weight, channel_size] or
......
......@@ -97,10 +97,7 @@ class FPN(tf.keras.Model):
conv2d = tf.keras.layers.SeparableConv2D
else:
conv2d = tf.keras.layers.Conv2D
if use_sync_bn:
norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
norm = tf.keras.layers.BatchNormalization
norm = tf.keras.layers.BatchNormalization
activation_fn = tf_utils.get_activation(activation, use_keras_layer=True)
# Build input feature pyramid.
......@@ -185,6 +182,7 @@ class FPN(tf.keras.Model):
axis=bn_axis,
momentum=norm_momentum,
epsilon=norm_epsilon,
synchronized=use_sync_bn,
name=f'norm_{level}')(
feats[str(level)])
......
......@@ -137,9 +137,7 @@ class NASFPN(tf.keras.Model):
self._conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv']
else tf.keras.layers.Conv2D)
self._norm_op = (tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
self._norm_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
......@@ -148,6 +146,7 @@ class NASFPN(tf.keras.Model):
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
}
self._activation = tf_utils.get_activation(activation)
......
......@@ -120,13 +120,12 @@ class DetectionHead(tf.keras.layers.Layer):
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_op = tf.keras.layers.BatchNormalization
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
}
self._convs = []
......@@ -314,13 +313,12 @@ class MaskHead(tf.keras.layers.Layer):
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_op = tf.keras.layers.BatchNormalization
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
}
self._convs = []
......
......@@ -108,13 +108,12 @@ class MaskScoring(tf.keras.Model):
'kernel_regularizer': self._config_dict['kernel_regularizer'],
'bias_regularizer': self._config_dict['bias_regularizer'],
})
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_op = tf.keras.layers.BatchNormalization
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
}
self._convs = []
......@@ -324,13 +323,12 @@ class SegmentationHead(tf.keras.layers.Layer):
"""Creates the variables of the segmentation head."""
use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
conv_op = tf.keras.layers.Conv2D
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn']
else tf.keras.layers.BatchNormalization)
bn_op = tf.keras.layers.BatchNormalization
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
'synchronized': self._config_dict['use_sync_bn'],
}
if self._config_dict['feature_fusion'] in {'deeplabv3plus',
......
......@@ -90,11 +90,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
channels = input_shape[3]
self.aspp_layers = []
if self.use_sync_bn:
bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
bn_op = tf.keras.layers.BatchNormalization
bn_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
......@@ -112,7 +108,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
bn_op(
axis=bn_axis,
momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon),
epsilon=self.batchnorm_epsilon,
synchronized=self.use_sync_bn),
tf.keras.layers.Activation(self.activation)
])
self.aspp_layers.append(conv_sequential)
......@@ -143,7 +140,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
bn_op(
axis=bn_axis,
momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon),
epsilon=self.batchnorm_epsilon,
synchronized=self.use_sync_bn),
tf.keras.layers.Activation(self.activation)
])
self.aspp_layers.append(conv_sequential)
......@@ -168,7 +166,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
bn_op(
axis=bn_axis,
momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon),
epsilon=self.batchnorm_epsilon,
synchronized=self.use_sync_bn),
tf.keras.layers.Activation(self.activation)
]))
......@@ -185,7 +184,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
bn_op(
axis=bn_axis,
momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon),
epsilon=self.batchnorm_epsilon,
synchronized=self.use_sync_bn),
tf.keras.layers.Activation(self.activation),
tf.keras.layers.Dropout(rate=self.dropout)
])
......
......@@ -123,11 +123,8 @@ class ResidualBlock(tf.keras.layers.Layer):
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
......@@ -150,7 +147,9 @@ class ResidualBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
conv1_padding = 'same'
# explicit padding here is added for centernet
......@@ -171,7 +170,9 @@ class ResidualBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
......@@ -186,7 +187,9 @@ class ResidualBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
......@@ -321,10 +324,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
......@@ -360,7 +361,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
......@@ -374,7 +377,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._activation1 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -392,7 +397,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._activation2 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -408,7 +415,9 @@ class BottleneckBlock(tf.keras.layers.Layer):
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
trainable=self._bn_trainable)
trainable=self._bn_trainable,
synchronized=self._use_sync_bn,
)
self._activation3 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -589,11 +598,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._bias_regularizer = bias_regularizer
self._expand_se_in_filters = expand_se_in_filters
self._output_intermediate_endpoints = output_intermediate_endpoints
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
......@@ -628,7 +634,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._activation_layer = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -648,7 +656,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._depthwise_activation_layer = tf_utils.get_activation(
self._depthwise_activation, use_keras_layer=True)
......@@ -686,7 +696,9 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
......@@ -812,11 +824,7 @@ class ResidualInner(tf.keras.layers.Layer):
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._batch_norm_first = batch_norm_first
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
......@@ -829,7 +837,9 @@ class ResidualInner(tf.keras.layers.Layer):
self._batch_norm_0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_1 = tf.keras.layers.Conv2D(
filters=self.filters,
......@@ -843,7 +853,9 @@ class ResidualInner(tf.keras.layers.Layer):
self._batch_norm_1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_2 = tf.keras.layers.Conv2D(
filters=self.filters,
......@@ -938,11 +950,7 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._batch_norm_first = batch_norm_first
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
......@@ -955,7 +963,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
self._batch_norm_0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_1 = tf.keras.layers.Conv2D(
filters=self.filters,
kernel_size=1,
......@@ -967,7 +977,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
self._batch_norm_1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_2 = tf.keras.layers.Conv2D(
filters=self.filters,
kernel_size=3,
......@@ -979,7 +991,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
self._batch_norm_2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv2d_3 = tf.keras.layers.Conv2D(
filters=self.filters * 4,
kernel_size=1,
......@@ -1257,11 +1271,8 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
......@@ -1301,7 +1312,9 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
......@@ -1314,7 +1327,9 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
super(DepthwiseSeparableConvBlock, self).build(input_shape)
......@@ -1398,11 +1413,8 @@ class TuckerConvBlock(tf.keras.layers.Layer):
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
......@@ -1426,7 +1438,9 @@ class TuckerConvBlock(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._activation_layer0 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -1447,7 +1461,9 @@ class TuckerConvBlock(tf.keras.layers.Layer):
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
self._activation_layer1 = tf_utils.get_activation(
self._activation, use_keras_layer=True)
......@@ -1464,7 +1480,9 @@ class TuckerConvBlock(tf.keras.layers.Layer):
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn,
)
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
......
......@@ -130,11 +130,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._norm = tf.keras.layers.BatchNormalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
......@@ -161,7 +158,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._temporal_conv = tf.keras.layers.Conv3D(
filters=self._filters,
......@@ -175,7 +173,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._spatial_conv = tf.keras.layers.Conv3D(
filters=self._filters,
......@@ -189,7 +188,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
self._expand_conv = tf.keras.layers.Conv3D(
filters=4 * self._filters,
......@@ -203,7 +203,8 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
self._norm3 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
epsilon=self._norm_epsilon,
synchronized=self._use_sync_bn)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
......
......@@ -1135,10 +1135,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self._pool_kernel_size = pool_kernel_size
self._use_depthwise_convolution = use_depthwise_convolution
self._activation_fn = tf_utils.get_activation(activation)
if self._use_sync_bn:
self._bn_op = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._bn_op = tf.keras.layers.BatchNormalization
self._bn_op = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
......@@ -1161,7 +1158,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
norm1 = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
epsilon=self._batchnorm_epsilon,
synchronized=self._use_sync_bn)
self.aspp_layers.append([conv1, norm1])
......@@ -1195,7 +1193,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
norm_dilation = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
epsilon=self._batchnorm_epsilon,
synchronized=self._use_sync_bn)
self.aspp_layers.append(conv_dilation + [norm_dilation])
......@@ -1216,7 +1215,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
norm2 = self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
epsilon=self._batchnorm_epsilon,
synchronized=self._use_sync_bn)
self.aspp_layers.append(pooling + [conv2, norm2])
......@@ -1234,7 +1234,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self._bn_op(
axis=self._bn_axis,
momentum=self._batchnorm_momentum,
epsilon=self._batchnorm_epsilon)
epsilon=self._batchnorm_epsilon,
synchronized=self._use_sync_bn)
]
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
self._concat_layer = tf.keras.layers.Concatenate(axis=-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册