Commit e12bd6a5 authored by Pengchong Jin, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 277961522
Parent: ada2ed77
...
@@ -232,11 +232,6 @@ class Parser(object):
     offset = image_info[3, :]
     boxes = input_utils.resize_and_crop_boxes(
         boxes, image_scale, (image_height, image_width), offset)
-    if self._include_mask:
-      masks = input_utils.resize_and_crop_masks(
-          tf.expand_dims(masks, axis=-1),
-          image_scale, (image_height, image_width), offset)
-      masks = tf.squeeze(masks, axis=-1)
 
     # Filters out ground truth boxes that are all zeros.
     indices = input_utils.get_non_empty_box_indices(boxes)
@@ -244,10 +239,14 @@ class Parser(object):
     classes = tf.gather(classes, indices)
     if self._include_mask:
       masks = tf.gather(masks, indices)
+      cropped_boxes = boxes + tf.cast(
+          tf.tile(tf.expand_dims(offset, axis=0), [1, 2]), dtype=tf.float32)
+      cropped_boxes = box_utils.normalize_boxes(
+          cropped_boxes, image_info[1, :])
       num_masks = tf.shape(masks)[0]
       masks = tf.image.crop_and_resize(
           tf.expand_dims(masks, axis=-1),
-          box_utils.normalize_boxes(boxes, tf.shape(image)[0:2]),
+          cropped_boxes,
           box_indices=tf.range(num_masks, dtype=tf.int32),
           crop_size=[self._mask_crop_size, self._mask_crop_size],
           method='bilinear')
...
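The change above stops resizing the full instance masks along with the image and instead crops fixed-size mask targets directly: the ground-truth boxes (already mapped into the resized-and-cropped image) are shifted by the crop `offset`, normalized against the size stored in `image_info[1, :]`, and passed to `tf.image.crop_and_resize`. Below is a minimal standalone sketch of that coordinate mapping with toy shapes; the inline normalization stands in for `box_utils.normalize_boxes`, and all values are illustrative rather than taken from the repo.

```python
import tensorflow as tf

# Toy inputs (shapes and values are illustrative only).
masks = tf.ones([3, 100, 150], tf.float32)           # [num_masks, H, W]
boxes = tf.constant([[10., 20., 60., 90.]] * 3)      # [y0, x0, y1, x1] per instance
offset = tf.constant([5., 7.])                        # crop offset, [y_offset, x_offset]
image_size = tf.constant([100., 150.])                # size the boxes are normalized against

# Add the offset back to every box corner: boxes + [y_off, x_off, y_off, x_off].
cropped_boxes = boxes + tf.cast(
    tf.tile(tf.expand_dims(offset, axis=0), [1, 2]), dtype=tf.float32)

# Inline stand-in for box_utils.normalize_boxes: scale corners into [0, 1].
y0, x0, y1, x1 = tf.unstack(cropped_boxes, axis=-1)
normalized_boxes = tf.stack(
    [y0 / image_size[0], x0 / image_size[1],
     y1 / image_size[0], x1 / image_size[1]], axis=-1)

# Crop one fixed-size mask target per instance, as the new code path does.
num_masks = tf.shape(masks)[0]
mask_targets = tf.image.crop_and_resize(
    tf.expand_dims(masks, axis=-1),                   # [num_masks, H, W, 1]
    normalized_boxes,
    box_indices=tf.range(num_masks, dtype=tf.int32),
    crop_size=[28, 28],
    method='bilinear')
print(mask_targets.shape)                              # (3, 28, 28, 1)
```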
@@ -104,7 +104,7 @@ def fast_rcnn_head_generator(params):
 def mask_rcnn_head_generator(params):
   """Generator function for Mask R-CNN head architecture."""
   return heads.MaskrcnnHead(params.num_classes,
-                            params.mrcnn_resolution,
+                            params.mask_target_size,
                             batch_norm_relu=batch_norm_relu_generator(
                                 params.batch_norm))
...
@@ -177,18 +177,18 @@ class MaskrcnnHead(object):
   def __init__(self,
                num_classes,
-               mrcnn_resolution,
+               mask_target_size,
                batch_norm_relu=nn_ops.BatchNormRelu):
     """Initialize params to build Fast R-CNN head.
 
     Args:
       num_classes: a integer for the number of classes.
-      mrcnn_resolution: a integer that is the resolution of masks.
+      mask_target_size: a integer that is the resolution of masks.
       batch_norm_relu: an operation that includes a batch normalization layer
         followed by a relu layer(optional).
     """
     self._num_classes = num_classes
-    self._mrcnn_resolution = mrcnn_resolution
+    self._mask_target_size = mask_target_size
     self._batch_norm_relu = batch_norm_relu
 
   def __call__(self, roi_features, class_indices, is_training=None):
@@ -272,7 +272,7 @@ class MaskrcnnHead(object):
               name='mask_fcn_logits')(
                   net)
       mask_outputs = tf.reshape(mask_outputs, [
-          -1, num_rois, self._mrcnn_resolution, self._mrcnn_resolution,
+          -1, num_rois, self._mask_target_size, self._mask_target_size,
           self._num_classes
       ])
...
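The two hunks above only rename `mrcnn_resolution` to `mask_target_size`; the reshape at the end still groups the per-RoI mask logits back by batch. A tiny standalone illustration of that reshape with toy shapes (values are illustrative):

```python
import tensorflow as tf

num_rois, mask_target_size, num_classes = 4, 28, 3    # illustrative values
# Conv output with batch and RoI dims flattened together: [batch * num_rois, S, S, C].
mask_outputs = tf.zeros([2 * num_rois, mask_target_size, mask_target_size, num_classes])

# Same reshape as in the hunk above, written with the renamed attribute.
mask_outputs = tf.reshape(mask_outputs, [
    -1, num_rois, mask_target_size, mask_target_size, num_classes
])
print(mask_outputs.shape)                               # (2, 4, 28, 28, 3)
```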
...
@@ -220,8 +220,8 @@ def sample_and_crop_foreground_masks(candidate_rois,
                                      candidate_gt_classes,
                                      candidate_gt_indices,
                                      gt_masks,
-                                     num_mask_samples_per_image=28,
-                                     cropped_mask_size=28):
+                                     num_mask_samples_per_image=128,
+                                     mask_target_size=28):
   """Samples and creates cropped foreground masks for training.
 
   Args:
@@ -243,7 +243,7 @@ def sample_and_crop_foreground_masks(candidate_rois,
       containing all the groundtruth masks which sample masks are drawn from.
     num_mask_samples_per_image: an integer which specifies the number of masks
       to sample.
-    cropped_mask_size: an integer which specifies the final cropped mask size
+    mask_target_size: an integer which specifies the final cropped mask size
       after sampling. The output masks are resized w.r.t the sampled RoIs.
 
   Returns:
@@ -253,7 +253,7 @@ def sample_and_crop_foreground_masks(candidate_rois,
     foreground_classes: a tensor of shape of [batch_size, K] storing the classes
       corresponding to the sampled foreground masks.
     cropoped_foreground_masks: a tensor of shape of
-      [batch_size, K, cropped_mask_size, cropped_mask_size] storing the cropped
+      [batch_size, K, mask_target_size, mask_target_size] storing the cropped
       foreground masks used for training.
   """
   with tf.name_scope('sample_and_crop_foreground_masks'):
@@ -268,23 +268,25 @@ def sample_and_crop_foreground_masks(candidate_rois,
     gather_nd_instance_indices = tf.stack(
         [batch_indices, fg_instance_indices], axis=-1)
-    foreground_rois = tf.gather_nd(candidate_rois, gather_nd_instance_indices)
+    foreground_rois = tf.gather_nd(
+        candidate_rois, gather_nd_instance_indices)
     foreground_boxes = tf.gather_nd(
         candidate_gt_boxes, gather_nd_instance_indices)
     foreground_classes = tf.gather_nd(
         candidate_gt_classes, gather_nd_instance_indices)
-    fg_gt_indices = tf.gather_nd(
+    foreground_gt_indices = tf.gather_nd(
         candidate_gt_indices, gather_nd_instance_indices)
-    fg_gt_indices_shape = tf.shape(fg_gt_indices)
+    foreground_gt_indices_shape = tf.shape(foreground_gt_indices)
     batch_indices = (
-        tf.expand_dims(tf.range(fg_gt_indices_shape[0]), axis=-1) *
-        tf.ones([1, fg_gt_indices_shape[-1]], dtype=tf.int32))
+        tf.expand_dims(tf.range(foreground_gt_indices_shape[0]), axis=-1) *
+        tf.ones([1, foreground_gt_indices_shape[-1]], dtype=tf.int32))
-    gather_nd_gt_indices = tf.stack([batch_indices, fg_gt_indices], axis=-1)
+    gather_nd_gt_indices = tf.stack(
+        [batch_indices, foreground_gt_indices], axis=-1)
     foreground_masks = tf.gather_nd(gt_masks, gather_nd_gt_indices)
     cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box(
-        foreground_masks, foreground_boxes, foreground_rois, cropped_mask_size)
+        foreground_masks, foreground_boxes, foreground_rois, mask_target_size)
     return foreground_rois, foreground_classes, cropped_foreground_masks
...
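Apart from the renames (`fg_gt_indices` → `foreground_gt_indices`, `cropped_mask_size` → `mask_target_size`) and the reflowed `tf.gather_nd`/`tf.stack` calls, the logic in this hunk is unchanged: per-row batch indices are paired with the per-image ground-truth indices so a single `tf.gather_nd` picks one ground-truth mask per sampled RoI across the whole batch. A minimal standalone sketch of that indexing pattern (toy shapes, not from the repo):

```python
import tensorflow as tf

# gt_masks: [batch_size, num_gt, H, W]; foreground_gt_indices: [batch_size, K].
gt_masks = tf.reshape(tf.range(2 * 3 * 2 * 2, dtype=tf.float32), [2, 3, 2, 2])
foreground_gt_indices = tf.constant([[2, 0], [1, 1]], dtype=tf.int32)  # K = 2

# Build matching batch indices: row i is filled with the value i.
foreground_gt_indices_shape = tf.shape(foreground_gt_indices)
batch_indices = (
    tf.expand_dims(tf.range(foreground_gt_indices_shape[0]), axis=-1) *
    tf.ones([1, foreground_gt_indices_shape[-1]], dtype=tf.int32))     # [[0, 0], [1, 1]]

# Pair (batch index, gt index) so gather_nd selects one mask per sampled RoI.
gather_nd_gt_indices = tf.stack(
    [batch_indices, foreground_gt_indices], axis=-1)                    # [batch, K, 2]
foreground_masks = tf.gather_nd(gt_masks, gather_nd_gt_indices)
print(foreground_masks.shape)                                           # (2, 2, 2, 2)
```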
@@ -345,7 +347,7 @@ class MaskSampler(object):
 
   def __init__(self, params):
     self._num_mask_samples_per_image = params.num_mask_samples_per_image
-    self._cropped_mask_size = params.cropped_mask_size
+    self._mask_target_size = params.mask_target_size
 
   def __call__(self,
                candidate_rois,
@@ -381,7 +383,7 @@ class MaskSampler(object):
       foreground_classes: a tensor of shape of [batch_size, K] storing the
         classes corresponding to the sampled foreground masks.
       cropoped_foreground_masks: a tensor of shape of
-        [batch_size, K, cropped_mask_size, cropped_mask_size] storing the
+        [batch_size, K, mask_target_size, mask_target_size] storing the
         cropped foreground masks used for training.
     """
     foreground_rois, foreground_classes, cropped_foreground_masks = (
@@ -392,5 +394,5 @@ class MaskSampler(object):
             candidate_gt_indices,
             gt_masks,
             self._num_mask_samples_per_image,
-            self._cropped_mask_size))
+            self._mask_target_size))
     return foreground_rois, foreground_classes, cropped_foreground_masks
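On the sampler side the change only renames the config field that is read; a hypothetical params object showing the renamed field and the new default sample count (the config class itself is not part of this diff, so this stand-in is illustrative):

```python
from types import SimpleNamespace

# Hypothetical stand-in for the real config object; field names follow the diff above.
mask_sampling_params = SimpleNamespace(
    num_mask_samples_per_image=128,  # matches the new default in sample_and_crop_foreground_masks
    mask_target_size=28)             # previously exposed as `cropped_mask_size`

# mask_sampler = MaskSampler(mask_sampling_params)  # MaskSampler as defined in this file
```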