diff --git a/official/vision/detection/dataloader/maskrcnn_parser.py b/official/vision/detection/dataloader/maskrcnn_parser.py index 296d9cc9d28deddb18b21bbbaf9400c99d889246..1fe6c40d461d3afded99dda1795bf799a8f408ee 100644 --- a/official/vision/detection/dataloader/maskrcnn_parser.py +++ b/official/vision/detection/dataloader/maskrcnn_parser.py @@ -232,11 +232,6 @@ class Parser(object): offset = image_info[3, :] boxes = input_utils.resize_and_crop_boxes( boxes, image_scale, (image_height, image_width), offset) - if self._include_mask: - masks = input_utils.resize_and_crop_masks( - tf.expand_dims(masks, axis=-1), - image_scale, (image_height, image_width), offset) - masks = tf.squeeze(masks, axis=-1) # Filters out ground truth boxes that are all zeros. indices = input_utils.get_non_empty_box_indices(boxes) @@ -244,10 +239,14 @@ class Parser(object): classes = tf.gather(classes, indices) if self._include_mask: masks = tf.gather(masks, indices) + cropped_boxes = boxes + tf.cast( + tf.tile(tf.expand_dims(offset, axis=0), [1, 2]), dtype=tf.float32) + cropped_boxes = box_utils.normalize_boxes( + cropped_boxes, image_info[1, :]) num_masks = tf.shape(masks)[0] masks = tf.image.crop_and_resize( tf.expand_dims(masks, axis=-1), - box_utils.normalize_boxes(boxes, tf.shape(image)[0:2]), + cropped_boxes, box_indices=tf.range(num_masks, dtype=tf.int32), crop_size=[self._mask_crop_size, self._mask_crop_size], method='bilinear') diff --git a/official/vision/detection/modeling/architecture/factory.py b/official/vision/detection/modeling/architecture/factory.py index 50512faf5ab3e8d0df7ef8a6fe4e1bece394b77f..6020e361a48ff85f05a019821afd2104e0476d57 100644 --- a/official/vision/detection/modeling/architecture/factory.py +++ b/official/vision/detection/modeling/architecture/factory.py @@ -104,7 +104,7 @@ def fast_rcnn_head_generator(params): def mask_rcnn_head_generator(params): """Generator function for Mask R-CNN head architecture.""" return heads.MaskrcnnHead(params.num_classes, - params.mrcnn_resolution, + params.mask_target_size, batch_norm_relu=batch_norm_relu_generator( params.batch_norm)) diff --git a/official/vision/detection/modeling/architecture/heads.py b/official/vision/detection/modeling/architecture/heads.py index 17e2067d7c2530f2796194122c20575a28e294c5..0cfbef6f6e546423a56b88b5096efced20b5b24a 100644 --- a/official/vision/detection/modeling/architecture/heads.py +++ b/official/vision/detection/modeling/architecture/heads.py @@ -177,18 +177,18 @@ class MaskrcnnHead(object): def __init__(self, num_classes, - mrcnn_resolution, + mask_target_size, batch_norm_relu=nn_ops.BatchNormRelu): """Initialize params to build Fast R-CNN head. Args: num_classes: a integer for the number of classes. - mrcnn_resolution: a integer that is the resolution of masks. + mask_target_size: a integer that is the resolution of masks. batch_norm_relu: an operation that includes a batch normalization layer followed by a relu layer(optional). """ self._num_classes = num_classes - self._mrcnn_resolution = mrcnn_resolution + self._mask_target_size = mask_target_size self._batch_norm_relu = batch_norm_relu def __call__(self, roi_features, class_indices, is_training=None): @@ -272,7 +272,7 @@ class MaskrcnnHead(object): name='mask_fcn_logits')( net) mask_outputs = tf.reshape(mask_outputs, [ - -1, num_rois, self._mrcnn_resolution, self._mrcnn_resolution, + -1, num_rois, self._mask_target_size, self._mask_target_size, self._num_classes ]) diff --git a/official/vision/detection/ops/sampling_ops.py b/official/vision/detection/ops/sampling_ops.py index 94dbe7f319800d6f51bf41f650c229ed99d2411e..c380324cd0125d9ad98d1ed50723222e6ed02b02 100644 --- a/official/vision/detection/ops/sampling_ops.py +++ b/official/vision/detection/ops/sampling_ops.py @@ -220,8 +220,8 @@ def sample_and_crop_foreground_masks(candidate_rois, candidate_gt_classes, candidate_gt_indices, gt_masks, - num_mask_samples_per_image=28, - cropped_mask_size=28): + num_mask_samples_per_image=128, + mask_target_size=28): """Samples and creates cropped foreground masks for training. Args: @@ -243,7 +243,7 @@ def sample_and_crop_foreground_masks(candidate_rois, containing all the groundtruth masks which sample masks are drawn from. num_mask_samples_per_image: an integer which specifies the number of masks to sample. - cropped_mask_size: an integer which specifies the final cropped mask size + mask_target_size: an integer which specifies the final cropped mask size after sampling. The output masks are resized w.r.t the sampled RoIs. Returns: @@ -253,7 +253,7 @@ def sample_and_crop_foreground_masks(candidate_rois, foreground_classes: a tensor of shape of [batch_size, K] storing the classes corresponding to the sampled foreground masks. cropoped_foreground_masks: a tensor of shape of - [batch_size, K, cropped_mask_size, cropped_mask_size] storing the cropped + [batch_size, K, mask_target_size, mask_target_size] storing the cropped foreground masks used for training. """ with tf.name_scope('sample_and_crop_foreground_masks'): @@ -268,23 +268,25 @@ def sample_and_crop_foreground_masks(candidate_rois, gather_nd_instance_indices = tf.stack( [batch_indices, fg_instance_indices], axis=-1) - foreground_rois = tf.gather_nd(candidate_rois, gather_nd_instance_indices) + foreground_rois = tf.gather_nd( + candidate_rois, gather_nd_instance_indices) foreground_boxes = tf.gather_nd( candidate_gt_boxes, gather_nd_instance_indices) foreground_classes = tf.gather_nd( candidate_gt_classes, gather_nd_instance_indices) - fg_gt_indices = tf.gather_nd( + foreground_gt_indices = tf.gather_nd( candidate_gt_indices, gather_nd_instance_indices) - fg_gt_indices_shape = tf.shape(fg_gt_indices) + foreground_gt_indices_shape = tf.shape(foreground_gt_indices) batch_indices = ( - tf.expand_dims(tf.range(fg_gt_indices_shape[0]), axis=-1) * - tf.ones([1, fg_gt_indices_shape[-1]], dtype=tf.int32)) - gather_nd_gt_indices = tf.stack([batch_indices, fg_gt_indices], axis=-1) + tf.expand_dims(tf.range(foreground_gt_indices_shape[0]), axis=-1) * + tf.ones([1, foreground_gt_indices_shape[-1]], dtype=tf.int32)) + gather_nd_gt_indices = tf.stack( + [batch_indices, foreground_gt_indices], axis=-1) foreground_masks = tf.gather_nd(gt_masks, gather_nd_gt_indices) cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box( - foreground_masks, foreground_boxes, foreground_rois, cropped_mask_size) + foreground_masks, foreground_boxes, foreground_rois, mask_target_size) return foreground_rois, foreground_classes, cropped_foreground_masks @@ -345,7 +347,7 @@ class MaskSampler(object): def __init__(self, params): self._num_mask_samples_per_image = params.num_mask_samples_per_image - self._cropped_mask_size = params.cropped_mask_size + self._mask_target_size = params.mask_target_size def __call__(self, candidate_rois, @@ -381,7 +383,7 @@ class MaskSampler(object): foreground_classes: a tensor of shape of [batch_size, K] storing the classes corresponding to the sampled foreground masks. cropoped_foreground_masks: a tensor of shape of - [batch_size, K, cropped_mask_size, cropped_mask_size] storing the + [batch_size, K, mask_target_size, mask_target_size] storing the cropped foreground masks used for training. """ foreground_rois, foreground_classes, cropped_foreground_masks = ( @@ -392,5 +394,5 @@ class MaskSampler(object): candidate_gt_indices, gt_masks, self._num_mask_samples_per_image, - self._cropped_mask_size)) + self._mask_target_size)) return foreground_rois, foreground_classes, cropped_foreground_masks