提交 8075ecc1 编写于 作者: C Corey Hu 提交者: Waleed

edit loss desc

上级 8bed8428
""" """
Mask R-CNN Mask R-CNN
The main Mask R-CNN model implemenetation. The main Mask R-CNN model implementation.
Copyright (c) 2017 Matterport, Inc. Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details) Licensed under the MIT License (see LICENSE for details)
...@@ -63,7 +63,7 @@ class BatchNorm(KL.BatchNormalization): ...@@ -63,7 +63,7 @@ class BatchNorm(KL.BatchNormalization):
Note about training values: Note about training values:
None: Train BN layers. This is the normal mode None: Train BN layers. This is the normal mode
False: Freeze BN layers. Good when batch size is small False: Freeze BN layers. Good when batch size is small
True: (don't use). Set layer in training mode even when inferencing True: (don't use). Set layer in training mode even when making inferences
""" """
return super(self.__class__, self).call(inputs, training=training) return super(self.__class__, self).call(inputs, training=training)
...@@ -97,12 +97,12 @@ def identity_block(input_tensor, kernel_size, filters, stage, block, ...@@ -97,12 +97,12 @@ def identity_block(input_tensor, kernel_size, filters, stage, block,
"""The identity_block is the block that has no conv layer at shortcut """The identity_block is the block that has no conv layer at shortcut
# Arguments # Arguments
input_tensor: input tensor input_tensor: input tensor
kernel_size: defualt 3, the kernel size of middle conv layer at main path kernel_size: default 3, the kernel size of middle conv layer at main path
filters: list of integers, the nb_filters of 3 conv layer at main path filters: list of integers, the nb_filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names block: 'a','b'..., current block label, used for generating layer names
use_bias: Boolean. To use or not use a bias in conv layers. use_bias: Boolean. To use or not use a bias in conv layers.
train_bn: Boolean. Train or freeze Batch Norm layres train_bn: Boolean. Train or freeze Batch Norm layers
""" """
nb_filter1, nb_filter2, nb_filter3 = filters nb_filter1, nb_filter2, nb_filter3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch' conv_name_base = 'res' + str(stage) + block + '_branch'
...@@ -132,12 +132,12 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, ...@@ -132,12 +132,12 @@ def conv_block(input_tensor, kernel_size, filters, stage, block,
"""conv_block is the block that has a conv layer at shortcut """conv_block is the block that has a conv layer at shortcut
# Arguments # Arguments
input_tensor: input tensor input_tensor: input tensor
kernel_size: defualt 3, the kernel size of middle conv layer at main path kernel_size: default 3, the kernel size of middle conv layer at main path
filters: list of integers, the nb_filters of 3 conv layer at main path filters: list of integers, the nb_filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names block: 'a','b'..., current block label, used for generating layer names
use_bias: Boolean. To use or not use a bias in conv layers. use_bias: Boolean. To use or not use a bias in conv layers.
train_bn: Boolean. Train or freeze Batch Norm layres train_bn: Boolean. Train or freeze Batch Norm layers
Note that from stage 3, the first conv layer at main path is with subsample=(2,2) Note that from stage 3, the first conv layer at main path is with subsample=(2,2)
And the shortcut should have subsample=(2,2) as well And the shortcut should have subsample=(2,2) as well
""" """
...@@ -172,7 +172,7 @@ def resnet_graph(input_image, architecture, stage5=False, train_bn=True): ...@@ -172,7 +172,7 @@ def resnet_graph(input_image, architecture, stage5=False, train_bn=True):
"""Build a ResNet graph. """Build a ResNet graph.
architecture: Can be resnet50 or resnet101 architecture: Can be resnet50 or resnet101
stage5: Boolean. If False, stage5 of the network is not created stage5: Boolean. If False, stage5 of the network is not created
train_bn: Boolean. Train or freeze Batch Norm layres train_bn: Boolean. Train or freeze Batch Norm layers
""" """
assert architecture in ["resnet50", "resnet101"] assert architecture in ["resnet50", "resnet101"]
# Stage 1 # Stage 1
...@@ -337,7 +337,7 @@ class ProposalLayer(KE.Layer): ...@@ -337,7 +337,7 @@ class ProposalLayer(KE.Layer):
############################################################ ############################################################
def log2_graph(x): def log2_graph(x):
"""Implementatin of Log2. TF doesn't have a native implemenation.""" """Implementation of Log2. TF doesn't have a native implementation."""
return tf.log(x) / tf.log(2.0) return tf.log(x) / tf.log(2.0)
...@@ -399,7 +399,7 @@ class PyramidROIAlign(KE.Layer): ...@@ -399,7 +399,7 @@ class PyramidROIAlign(KE.Layer):
ix = tf.where(tf.equal(roi_level, level)) ix = tf.where(tf.equal(roi_level, level))
level_boxes = tf.gather_nd(boxes, ix) level_boxes = tf.gather_nd(boxes, ix)
# Box indicies for crop_and_resize. # Box indices for crop_and_resize.
box_indices = tf.cast(ix[:, 0], tf.int32) box_indices = tf.cast(ix[:, 0], tf.int32)
# Keep track of which box is mapped to which level # Keep track of which box is mapped to which level
...@@ -457,9 +457,9 @@ def overlaps_graph(boxes1, boxes2): ...@@ -457,9 +457,9 @@ def overlaps_graph(boxes1, boxes2):
"""Computes IoU overlaps between two sets of boxes. """Computes IoU overlaps between two sets of boxes.
boxes1, boxes2: [N, (y1, x1, y2, x2)]. boxes1, boxes2: [N, (y1, x1, y2, x2)].
""" """
# 1. Tile boxes2 and repeate boxes1. This allows us to compare # 1. Tile boxes2 and repeat boxes1. This allows us to compare
# every boxes1 against every boxes2 without loops. # every boxes1 against every boxes2 without loops.
# TF doesn't have an equivalent to np.repeate() so simulate it # TF doesn't have an equivalent to np.repeat() so simulate it
# using tf.tile() and tf.reshape. # using tf.tile() and tf.reshape.
b1 = tf.reshape(tf.tile(tf.expand_dims(boxes1, 1), b1 = tf.reshape(tf.tile(tf.expand_dims(boxes1, 1),
[1, 1, tf.shape(boxes2)[0]]), [-1, 4]) [1, 1, tf.shape(boxes2)[0]]), [-1, 4])
...@@ -539,7 +539,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config) ...@@ -539,7 +539,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
crowd_iou_max = tf.reduce_max(crowd_overlaps, axis=1) crowd_iou_max = tf.reduce_max(crowd_overlaps, axis=1)
no_crowd_bool = (crowd_iou_max < 0.001) no_crowd_bool = (crowd_iou_max < 0.001)
# Determine postive and negative ROIs # Determine positive and negative ROIs
roi_iou_max = tf.reduce_max(overlaps, axis=1) roi_iou_max = tf.reduce_max(overlaps, axis=1)
# 1. Positive ROIs are those with >= 0.5 IoU with a GT box # 1. Positive ROIs are those with >= 0.5 IoU with a GT box
positive_roi_bool = (roi_iou_max >= 0.5) positive_roi_bool = (roi_iou_max >= 0.5)
...@@ -584,7 +584,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config) ...@@ -584,7 +584,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
# Compute mask targets # Compute mask targets
boxes = positive_rois boxes = positive_rois
if config.USE_MINI_MASK: if config.USE_MINI_MASK:
# Transform ROI corrdinates from normalized image space # Transform ROI coordinates from normalized image space
# to normalized mini-mask space. # to normalized mini-mask space.
y1, x1, y2, x2 = tf.split(positive_rois, 4, axis=1) y1, x1, y2, x2 = tf.split(positive_rois, 4, axis=1)
gt_y1, gt_x1, gt_y2, gt_x2 = tf.split(roi_gt_boxes, 4, axis=1) gt_y1, gt_x1, gt_y2, gt_x2 = tf.split(roi_gt_boxes, 4, axis=1)
...@@ -741,7 +741,7 @@ def refine_detections_graph(rois, probs, deltas, window, config): ...@@ -741,7 +741,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
tf.gather(pre_nms_scores, ixs), tf.gather(pre_nms_scores, ixs),
max_output_size=config.DETECTION_MAX_INSTANCES, max_output_size=config.DETECTION_MAX_INSTANCES,
iou_threshold=config.DETECTION_NMS_THRESHOLD) iou_threshold=config.DETECTION_NMS_THRESHOLD)
# Map indicies # Map indices
class_keep = tf.gather(keep, tf.gather(ixs, class_keep)) class_keep = tf.gather(keep, tf.gather(ixs, class_keep))
# Pad with -1 so returned tensors have the same shape # Pad with -1 so returned tensors have the same shape
gap = config.DETECTION_MAX_INSTANCES - tf.shape(class_keep)[0] gap = config.DETECTION_MAX_INSTANCES - tf.shape(class_keep)[0]
...@@ -844,8 +844,8 @@ def rpn_graph(feature_map, anchors_per_location, anchor_stride): ...@@ -844,8 +844,8 @@ def rpn_graph(feature_map, anchors_per_location, anchor_stride):
rpn_bbox: [batch, H, W, (dy, dx, log(dh), log(dw))] Deltas to be rpn_bbox: [batch, H, W, (dy, dx, log(dh), log(dw))] Deltas to be
applied to anchors. applied to anchors.
""" """
# TODO: check if stride of 2 causes alignment issues if the featuremap # TODO: check if stride of 2 causes alignment issues if the feature map
# is not even. # is not even.
# Shared convolutional base of the RPN # Shared convolutional base of the RPN
shared = KL.Conv2D(512, (3, 3), padding='same', activation='relu', shared = KL.Conv2D(512, (3, 3), padding='same', activation='relu',
strides=anchor_stride, strides=anchor_stride,
...@@ -908,12 +908,12 @@ def fpn_classifier_graph(rois, feature_maps, image_meta, ...@@ -908,12 +908,12 @@ def fpn_classifier_graph(rois, feature_maps, image_meta,
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
coordinates. coordinates.
feature_maps: List of feature maps from diffent layers of the pyramid, feature_maps: List of feature maps from different layers of the pyramid,
[P2, P3, P4, P5]. Each has a different resolution. [P2, P3, P4, P5]. Each has a different resolution.
- image_meta: [batch, (meta data)] Image details. See compose_image_meta() - image_meta: [batch, (meta data)] Image details. See compose_image_meta()
pool_size: The width of the square feature map generated from ROI Pooling. pool_size: The width of the square feature map generated from ROI Pooling.
num_classes: number of classes, which determines the depth of the results num_classes: number of classes, which determines the depth of the results
train_bn: Boolean. Train or freeze Batch Norm layres train_bn: Boolean. Train or freeze Batch Norm layers
fc_layers_size: Size of the 2 FC layers fc_layers_size: Size of the 2 FC layers
Returns: Returns:
...@@ -962,12 +962,12 @@ def build_fpn_mask_graph(rois, feature_maps, image_meta, ...@@ -962,12 +962,12 @@ def build_fpn_mask_graph(rois, feature_maps, image_meta,
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
coordinates. coordinates.
feature_maps: List of feature maps from diffent layers of the pyramid, feature_maps: List of feature maps from different layers of the pyramid,
[P2, P3, P4, P5]. Each has a different resolution. [P2, P3, P4, P5]. Each has a different resolution.
image_meta: [batch, (meta data)] Image details. See compose_image_meta() image_meta: [batch, (meta data)] Image details. See compose_image_meta()
pool_size: The width of the square feature map generated from ROI Pooling. pool_size: The width of the square feature map generated from ROI Pooling.
num_classes: number of classes, which determines the depth of the results num_classes: number of classes, which determines the depth of the results
train_bn: Boolean. Train or freeze Batch Norm layres train_bn: Boolean. Train or freeze Batch Norm layers
Returns: Masks [batch, roi_count, height, width, num_classes] Returns: Masks [batch, roi_count, height, width, num_classes]
""" """
...@@ -1014,7 +1014,7 @@ def build_fpn_mask_graph(rois, feature_maps, image_meta, ...@@ -1014,7 +1014,7 @@ def build_fpn_mask_graph(rois, feature_maps, image_meta,
def smooth_l1_loss(y_true, y_pred): def smooth_l1_loss(y_true, y_pred):
"""Implements Smooth-L1 loss. """Implements Smooth-L1 loss.
y_true and y_pred are typicallly: [N, 4], but could be any shape. y_true and y_pred are typically: [N, 4], but could be any shape.
""" """
diff = K.abs(y_true - y_pred) diff = K.abs(y_true - y_pred)
less_than_one = K.cast(K.less(diff, 1.0), "float32") less_than_one = K.cast(K.less(diff, 1.0), "float32")
...@@ -1039,7 +1039,7 @@ def rpn_class_loss_graph(rpn_match, rpn_class_logits): ...@@ -1039,7 +1039,7 @@ def rpn_class_loss_graph(rpn_match, rpn_class_logits):
# Pick rows that contribute to the loss and filter out the rest. # Pick rows that contribute to the loss and filter out the rest.
rpn_class_logits = tf.gather_nd(rpn_class_logits, indices) rpn_class_logits = tf.gather_nd(rpn_class_logits, indices)
anchor_class = tf.gather_nd(anchor_class, indices) anchor_class = tf.gather_nd(anchor_class, indices)
# Crossentropy loss # Cross entropy loss
loss = K.sparse_categorical_crossentropy(target=anchor_class, loss = K.sparse_categorical_crossentropy(target=anchor_class,
output=rpn_class_logits, output=rpn_class_logits,
from_logits=True) from_logits=True)
...@@ -1129,7 +1129,7 @@ def mrcnn_bbox_loss_graph(target_bbox, target_class_ids, pred_bbox): ...@@ -1129,7 +1129,7 @@ def mrcnn_bbox_loss_graph(target_bbox, target_class_ids, pred_bbox):
pred_bbox = K.reshape(pred_bbox, (-1, K.int_shape(pred_bbox)[2], 4)) pred_bbox = K.reshape(pred_bbox, (-1, K.int_shape(pred_bbox)[2], 4))
# Only positive ROIs contribute to the loss. And only # Only positive ROIs contribute to the loss. And only
# the right class_id of each ROI. Get their indicies. # the right class_id of each ROI. Get their indices.
positive_roi_ix = tf.where(target_class_ids > 0)[:, 0] positive_roi_ix = tf.where(target_class_ids > 0)[:, 0]
positive_roi_class_ids = tf.cast( positive_roi_class_ids = tf.cast(
tf.gather(target_class_ids, positive_roi_ix), tf.int64) tf.gather(target_class_ids, positive_roi_ix), tf.int64)
...@@ -1194,7 +1194,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None, ...@@ -1194,7 +1194,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
use_mini_mask=False): use_mini_mask=False):
"""Load and return ground truth data for an image (image, mask, bounding boxes). """Load and return ground truth data for an image (image, mask, bounding boxes).
augment: (Depricated. Use augmentation instead). If true, apply random augment: (deprecated. Use augmentation instead). If true, apply random
image augmentation. Currently, only horizontal flipping is offered. image augmentation. Currently, only horizontal flipping is offered.
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation. augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
For example, passing imgaug.augmenters.Fliplr(0.5) flips images For example, passing imgaug.augmenters.Fliplr(0.5) flips images
...@@ -1229,7 +1229,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None, ...@@ -1229,7 +1229,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
# Random horizontal flips. # Random horizontal flips.
# TODO: will be removed in a future update in favor of augmentation # TODO: will be removed in a future update in favor of augmentation
if augment: if augment:
logging.warning("'augment' is depricated. Use 'augmentation' instead.") logging.warning("'augment' is deprecated. Use 'augmentation' instead.")
if random.randint(0, 1): if random.randint(0, 1):
image = np.fliplr(image) image = np.fliplr(image)
mask = np.fliplr(mask) mask = np.fliplr(mask)
...@@ -1239,7 +1239,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None, ...@@ -1239,7 +1239,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
if augmentation: if augmentation:
import imgaug import imgaug
# Augmentors that are safe to apply to masks # Augmenters that are safe to apply to masks
# Some, such as Affine, have settings that make them unsafe, so always # Some, such as Affine, have settings that make them unsafe, so always
# test your augmentation on masks # test your augmentation on masks
MASK_AUGMENTERS = ["Sequential", "SomeOf", "OneOf", "Sometimes", MASK_AUGMENTERS = ["Sequential", "SomeOf", "OneOf", "Sometimes",
...@@ -1248,7 +1248,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None, ...@@ -1248,7 +1248,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
def hook(images, augmenter, parents, default): def hook(images, augmenter, parents, default):
"""Determines which augmenters to apply to masks.""" """Determines which augmenters to apply to masks."""
return (augmenter.__class__.__name__ in MASK_AUGMENTERS) return augmenter.__class__.__name__ in MASK_AUGMENTERS
# Store shapes before augmentation to compare # Store shapes before augmentation to compare
image_shape = image.shape image_shape = image.shape
...@@ -1302,7 +1302,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config): ...@@ -1302,7 +1302,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
rpn_rois: [N, (y1, x1, y2, x2)] proposal boxes. rpn_rois: [N, (y1, x1, y2, x2)] proposal boxes.
gt_class_ids: [instance count] Integer class IDs gt_class_ids: [instance count] Integer class IDs
gt_boxes: [instance count, (y1, x1, y2, x2)] gt_boxes: [instance count, (y1, x1, y2, x2)]
gt_masks: [height, width, instance count] Grund truth masks. Can be full gt_masks: [height, width, instance count] Ground truth masks. Can be full
size or mini-masks. size or mini-masks.
Returns: Returns:
...@@ -1357,7 +1357,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config): ...@@ -1357,7 +1357,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
# Negative ROIs are those with max IoU 0.1-0.5 (hard example mining) # Negative ROIs are those with max IoU 0.1-0.5 (hard example mining)
# TODO: To hard example mine or not to hard example mine, that's the question # TODO: To hard example mine or not to hard example mine, that's the question
# bg_ids = np.where((rpn_roi_iou_max >= 0.1) & (rpn_roi_iou_max < 0.5))[0] # bg_ids = np.where((rpn_roi_iou_max >= 0.1) & (rpn_roi_iou_max < 0.5))[0]
bg_ids = np.where(rpn_roi_iou_max < 0.5)[0] bg_ids = np.where(rpn_roi_iou_max < 0.5)[0]
# Subsample ROIs. Aim for 33% foreground. # Subsample ROIs. Aim for 33% foreground.
...@@ -1373,7 +1373,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config): ...@@ -1373,7 +1373,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
keep_bg_ids = np.random.choice(bg_ids, remaining, replace=False) keep_bg_ids = np.random.choice(bg_ids, remaining, replace=False)
else: else:
keep_bg_ids = bg_ids keep_bg_ids = bg_ids
# Combine indicies of ROIs to keep # Combine indices of ROIs to keep
keep = np.concatenate([keep_fg_ids, keep_bg_ids]) keep = np.concatenate([keep_fg_ids, keep_bg_ids])
# Need more? # Need more?
remaining = config.TRAIN_ROIS_PER_IMAGE - keep.shape[0] remaining = config.TRAIN_ROIS_PER_IMAGE - keep.shape[0]
...@@ -1644,7 +1644,7 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No ...@@ -1644,7 +1644,7 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No
dataset: The Dataset object to pick data from dataset: The Dataset object to pick data from
config: The model config object config: The model config object
shuffle: If True, shuffles the samples before every epoch shuffle: If True, shuffles the samples before every epoch
augment: (Depricated. Use augmentation instead). If true, apply random augment: (deprecated. Use augmentation instead). If true, apply random
image augmentation. Currently, only horizontal flipping is offered. image augmentation. Currently, only horizontal flipping is offered.
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation. augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
For example, passing imgaug.augmenters.Fliplr(0.5) flips images For example, passing imgaug.augmenters.Fliplr(0.5) flips images
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册