提交 a9890745 编写于 作者: W Waleed Abdulla

Config setting to enable/disable BN training

Allows using a config variable, TRAIN_NB, rather
that updating the BatchNorm class.
上级 94196f5c
......@@ -148,6 +148,12 @@ class Config(object):
# train the RPN.
USE_RPN_ROIS = True
# Train or freeze batch normalization layers
# None: Train BN layers. This is the normal mode
# False: Freeze BN layers. Good when using a small batch size
# True: (don't use). Set layer in training mode even when inferencing
TRAIN_BN = False # Defaulting to False since batch size is often small
# Gradient norm clipping
GRADIENT_CLIP_NORM = 5.0
......
......@@ -55,16 +55,21 @@ def log(text, array=None):
class BatchNorm(KL.BatchNormalization):
"""Batch Normalization class. Subclasses the Keras BN class and
hardcodes training=False so the BN layer doesn't update
during training.
"""Extends the Keras BatchNormalization class to allow a central place
to make changes if needed.
Batch normalization has a negative effect on training if batches are small
so we disable it here.
so this layer is often frozen (via setting in Config class) and functions
as linear layer.
"""
def call(self, inputs, training=None):
return super(self.__class__, self).call(inputs, training=False)
"""
Note about training values:
None: Train BN layers. This is the normal mode
False: Freeze BN layers. Good when batch size is small
True: (don't use). Set layer in training mode even when inferencing
"""
return super(self.__class__, self).call(inputs, training=training)
############################################################
......@@ -75,7 +80,7 @@ class BatchNorm(KL.BatchNormalization):
# https://github.com/fchollet/deep-learning-models/blob/master/resnet50.py
def identity_block(input_tensor, kernel_size, filters, stage, block,
use_bias=True):
use_bias=True, train_bn=True):
"""The identity_block is the block that has no conv layer at shortcut
# Arguments
input_tensor: input tensor
......@@ -83,6 +88,8 @@ def identity_block(input_tensor, kernel_size, filters, stage, block,
filters: list of integers, the nb_filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
use_bias: Boolean. To use or not use a bias in conv layers.
train_bn: Boolean. Train or freeze Batch Norm layres
"""
nb_filter1, nb_filter2, nb_filter3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch'
......@@ -90,17 +97,17 @@ def identity_block(input_tensor, kernel_size, filters, stage, block,
x = KL.Conv2D(nb_filter1, (1, 1), name=conv_name_base + '2a',
use_bias=use_bias)(input_tensor)
x = BatchNorm(axis=3, name=bn_name_base + '2a')(x)
x = BatchNorm(name=bn_name_base + '2a')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.Conv2D(nb_filter2, (kernel_size, kernel_size), padding='same',
name=conv_name_base + '2b', use_bias=use_bias)(x)
x = BatchNorm(axis=3, name=bn_name_base + '2b')(x)
x = BatchNorm(name=bn_name_base + '2b')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.Conv2D(nb_filter3, (1, 1), name=conv_name_base + '2c',
use_bias=use_bias)(x)
x = BatchNorm(axis=3, name=bn_name_base + '2c')(x)
x = BatchNorm(name=bn_name_base + '2c')(x, training=train_bn)
x = KL.Add()([x, input_tensor])
x = KL.Activation('relu', name='res' + str(stage) + block + '_out')(x)
......@@ -108,7 +115,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block,
def conv_block(input_tensor, kernel_size, filters, stage, block,
strides=(2, 2), use_bias=True):
strides=(2, 2), use_bias=True, train_bn=True):
"""conv_block is the block that has a conv layer at shortcut
# Arguments
input_tensor: input tensor
......@@ -116,6 +123,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block,
filters: list of integers, the nb_filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
use_bias: Boolean. To use or not use a bias in conv layers.
train_bn: Boolean. Train or freeze Batch Norm layres
Note that from stage 3, the first conv layer at main path is with subsample=(2,2)
And the shortcut should have subsample=(2,2) as well
"""
......@@ -125,55 +134,60 @@ def conv_block(input_tensor, kernel_size, filters, stage, block,
x = KL.Conv2D(nb_filter1, (1, 1), strides=strides,
name=conv_name_base + '2a', use_bias=use_bias)(input_tensor)
x = BatchNorm(axis=3, name=bn_name_base + '2a')(x)
x = BatchNorm(name=bn_name_base + '2a')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.Conv2D(nb_filter2, (kernel_size, kernel_size), padding='same',
name=conv_name_base + '2b', use_bias=use_bias)(x)
x = BatchNorm(axis=3, name=bn_name_base + '2b')(x)
x = BatchNorm(name=bn_name_base + '2b')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.Conv2D(nb_filter3, (1, 1), name=conv_name_base +
'2c', use_bias=use_bias)(x)
x = BatchNorm(axis=3, name=bn_name_base + '2c')(x)
x = BatchNorm(name=bn_name_base + '2c')(x, training=train_bn)
shortcut = KL.Conv2D(nb_filter3, (1, 1), strides=strides,
name=conv_name_base + '1', use_bias=use_bias)(input_tensor)
shortcut = BatchNorm(axis=3, name=bn_name_base + '1')(shortcut)
shortcut = BatchNorm(name=bn_name_base + '1')(shortcut, training=train_bn)
x = KL.Add()([x, shortcut])
x = KL.Activation('relu', name='res' + str(stage) + block + '_out')(x)
return x
def resnet_graph(input_image, architecture, stage5=False):
def resnet_graph(input_image, architecture, stage5=False, train_bn=True):
"""Build a ResNet graph.
architecture: Can be resnet50 or resnet101
stage5: Boolean. If False, stage5 of the network is not created
train_bn: Boolean. Train or freeze Batch Norm layres
"""
assert architecture in ["resnet50", "resnet101"]
# Stage 1
x = KL.ZeroPadding2D((3, 3))(input_image)
x = KL.Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=True)(x)
x = BatchNorm(axis=3, name='bn_conv1')(x)
x = BatchNorm(name='bn_conv1')(x, training=train_bn)
x = KL.Activation('relu')(x)
C1 = x = KL.MaxPooling2D((3, 3), strides=(2, 2), padding="same")(x)
# Stage 2
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
C2 = x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), train_bn=train_bn)
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', train_bn=train_bn)
C2 = x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', train_bn=train_bn)
# Stage 3
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
C3 = x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', train_bn=train_bn)
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', train_bn=train_bn)
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', train_bn=train_bn)
C3 = x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', train_bn=train_bn)
# Stage 4
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', train_bn=train_bn)
block_count = {"resnet50": 5, "resnet101": 22}[architecture]
for i in range(block_count):
x = identity_block(x, 3, [256, 256, 1024], stage=4, block=chr(98 + i))
x = identity_block(x, 3, [256, 256, 1024], stage=4, block=chr(98 + i), train_bn=train_bn)
C4 = x
# Stage 5
if stage5:
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
C5 = x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', train_bn=train_bn)
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', train_bn=train_bn)
C5 = x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', train_bn=train_bn)
else:
C5 = None
return [C1, C2, C3, C4, C5]
......@@ -883,7 +897,7 @@ def build_rpn_model(anchor_stride, anchors_per_location, depth):
############################################################
def fpn_classifier_graph(rois, feature_maps,
image_shape, pool_size, num_classes):
image_shape, pool_size, num_classes, train_bn=True):
"""Builds the computation graph of the feature pyramid network classifier
and regressor heads.
......@@ -894,6 +908,7 @@ def fpn_classifier_graph(rois, feature_maps,
image_shape: [height, width, depth]
pool_size: The width of the square feature map generated from ROI Pooling.
num_classes: number of classes, which determines the depth of the results
train_bn: Boolean. Train or freeze Batch Norm layres
Returns:
logits: [N, NUM_CLASSES] classifier logits (before softmax)
......@@ -908,12 +923,11 @@ def fpn_classifier_graph(rois, feature_maps,
# Two 1024 FC layers (implemented with Conv2D for consistency)
x = KL.TimeDistributed(KL.Conv2D(1024, (pool_size, pool_size), padding="valid"),
name="mrcnn_class_conv1")(x)
x = KL.TimeDistributed(BatchNorm(axis=3), name='mrcnn_class_bn1')(x)
x = KL.TimeDistributed(BatchNorm(), name='mrcnn_class_bn1')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.TimeDistributed(KL.Conv2D(1024, (1, 1)),
name="mrcnn_class_conv2")(x)
x = KL.TimeDistributed(BatchNorm(axis=3),
name='mrcnn_class_bn2')(x)
x = KL.TimeDistributed(BatchNorm(), name='mrcnn_class_bn2')(x, training=train_bn)
x = KL.Activation('relu')(x)
shared = KL.Lambda(lambda x: K.squeeze(K.squeeze(x, 3), 2),
......@@ -936,8 +950,8 @@ def fpn_classifier_graph(rois, feature_maps,
return mrcnn_class_logits, mrcnn_probs, mrcnn_bbox
def build_fpn_mask_graph(rois, feature_maps,
image_shape, pool_size, num_classes):
def build_fpn_mask_graph(rois, feature_maps, image_shape,
pool_size, num_classes, train_bn=True):
"""Builds the computation graph of the mask head of Feature Pyramid Network.
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
......@@ -947,6 +961,7 @@ def build_fpn_mask_graph(rois, feature_maps,
image_shape: [height, width, depth]
pool_size: The width of the square feature map generated from ROI Pooling.
num_classes: number of classes, which determines the depth of the results
train_bn: Boolean. Train or freeze Batch Norm layres
Returns: Masks [batch, roi_count, height, width, num_classes]
"""
......@@ -958,26 +973,26 @@ def build_fpn_mask_graph(rois, feature_maps,
# Conv layers
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
name="mrcnn_mask_conv1")(x)
x = KL.TimeDistributed(BatchNorm(axis=3),
name='mrcnn_mask_bn1')(x)
x = KL.TimeDistributed(BatchNorm(),
name='mrcnn_mask_bn1')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
name="mrcnn_mask_conv2")(x)
x = KL.TimeDistributed(BatchNorm(axis=3),
name='mrcnn_mask_bn2')(x)
x = KL.TimeDistributed(BatchNorm(),
name='mrcnn_mask_bn2')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
name="mrcnn_mask_conv3")(x)
x = KL.TimeDistributed(BatchNorm(axis=3),
name='mrcnn_mask_bn3')(x)
x = KL.TimeDistributed(BatchNorm(),
name='mrcnn_mask_bn3')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
name="mrcnn_mask_conv4")(x)
x = KL.TimeDistributed(BatchNorm(axis=3),
name='mrcnn_mask_bn4')(x)
x = KL.TimeDistributed(BatchNorm(),
name='mrcnn_mask_bn4')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.TimeDistributed(KL.Conv2DTranspose(256, (2, 2), strides=2, activation="relu"),
......@@ -1861,7 +1876,8 @@ class MaskRCNN():
# Bottom-up Layers
# Returns a list of the last layers of each stage, 5 in total.
# Don't create the thead (stage 5), so we pick the 4th item in the list.
_, C2, C3, C4, C5 = resnet_graph(input_image, config.BACKBONE, stage5=True)
_, C2, C3, C4, C5 = resnet_graph(input_image, config.BACKBONE,
stage5=True, train_bn=config.TRAIN_BN)
# Top-down Layers
# TODO: add assert to varify feature map sizes match what's in config
P5 = KL.Conv2D(256, (1, 1), name='fpn_c5p5')(C5)
......@@ -1951,12 +1967,14 @@ class MaskRCNN():
# TODO: verify that this handles zero padded ROIs
mrcnn_class_logits, mrcnn_class, mrcnn_bbox =\
fpn_classifier_graph(rois, mrcnn_feature_maps, config.IMAGE_SHAPE,
config.POOL_SIZE, config.NUM_CLASSES)
config.POOL_SIZE, config.NUM_CLASSES,
train_bn=config.TRAIN_BN)
mrcnn_mask = build_fpn_mask_graph(rois, mrcnn_feature_maps,
config.IMAGE_SHAPE,
config.MASK_POOL_SIZE,
config.NUM_CLASSES)
config.NUM_CLASSES,
train_bn=config.TRAIN_BN)
# TODO: clean up (use tf.identify if necessary)
output_rois = KL.Lambda(lambda x: x * 1, name="output_rois")(rois)
......@@ -1988,7 +2006,8 @@ class MaskRCNN():
# Proposal classifier and BBox regressor heads
mrcnn_class_logits, mrcnn_class, mrcnn_bbox =\
fpn_classifier_graph(rpn_rois, mrcnn_feature_maps, config.IMAGE_SHAPE,
config.POOL_SIZE, config.NUM_CLASSES)
config.POOL_SIZE, config.NUM_CLASSES,
train_bn=config.TRAIN_BN)
# Detections
# output is [batch, num_detections, (y1, x1, y2, x2, class_id, score)] in image coordinates
......@@ -2006,7 +2025,8 @@ class MaskRCNN():
mrcnn_mask = build_fpn_mask_graph(detection_boxes, mrcnn_feature_maps,
config.IMAGE_SHAPE,
config.MASK_POOL_SIZE,
config.NUM_CLASSES)
config.NUM_CLASSES,
train_bn=config.TRAIN_BN)
model = KM.Model([input_image, input_image_meta],
[detections, mrcnn_class, mrcnn_bbox,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册