Unverified commit 59f7e80a authored by pkulzc, committed by GitHub

Update object detection post-processing and fix the box padding/clipping issue. (#5026)

* Merged commit includes the following changes:
207771702  by Zhichao Lu:

    Refactoring evaluation utilities so that it is easier to introduce new DetectionEvaluators with eval_metric_ops.
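
    A hedged sketch of the refactored flow, assuming `eval_config`,
    `categories`, and `eval_dict` are already built; all three function names
    appear in the eval_util diff below:

        from object_detection import eval_util

        # Derive per-metric constructor kwargs from the eval config, build the
        # evaluators, then collect their estimator-compatible metric ops.
        evaluator_options = eval_util.evaluator_options_from_eval_config(
            eval_config)
        evaluators = eval_util.get_evaluators(
            eval_config, categories, evaluator_options)
        eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
            eval_config, categories, eval_dict)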

--
207758641  by Zhichao Lu:

    Require tensorflow version 1.9+ for running object detection API.
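
    One hedged way to enforce such a requirement at import time (illustrative
    only; not necessarily how the codebase checks it):

        import tensorflow as tf

        # Compare only the major.minor components of the version string.
        if tuple(int(v) for v in tf.__version__.split('.')[:2]) < (1, 9):
          raise ImportError('The object detection API requires TensorFlow 1.9+.')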

--
207641470  by Zhichao Lu:

    Clip `num_groundtruth_boxes` in pad_input_data_to_static_shapes() to `max_num_boxes`. This prevents a scenario where tensors are sliced to an invalid range in model_lib.unstack_batch().
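
    A minimal sketch of the fix, with assumed variable names (the real change
    lives inside pad_input_data_to_static_shapes()):

        import tensorflow as tf

        # Clamp the reported box count so that later slicing to
        # [:num_groundtruth_boxes] can never exceed the padded tensor size.
        num_groundtruth_boxes = tf.minimum(num_groundtruth_boxes, max_num_boxes)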

--
207621728  by Zhichao Lu:

    This CL adds a FreezableBatchNorm that inherits from the Keras BatchNormalization layer, but supports freezing the `training` parameter at construction time instead of having to do it in the `call` method.

    It also adds a method to the `KerasLayerHyperparams` class that will build an appropriate FreezableBatchNorm layer according to the hyperparameter configuration. If batch_norm is disabled, this method returns an Identity layer.

    These will be used to simplify the conversion to Keras APIs.
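
    A usage sketch, assuming `hyperparams` is a KerasLayerHyperparams instance
    built from a config (`build_batch_norm` and FreezableBatchNorm both appear
    in the diffs below):

        # Freeze batch norm at construction time: the layer then always uses its
        # moving statistics, as at inference, regardless of the learning phase.
        frozen_bn = hyperparams.build_batch_norm(training=False)
        outputs = frozen_bn(inputs)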

--
207610524  by Zhichao Lu:

    Update anchor generators and box predictors for python3 compatibility.

--
207585122  by Zhichao Lu:

    Refactoring convolutional box predictor into separate prediction heads.

--
207549305  by Zhichao Lu:

    Pass all-ones batch weights if none are specified in the groundtruth.
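
    A minimal sketch of that default, with assumed variable names:

        import tensorflow as tf

        # When the groundtruth carries no per-box weights, fall back to ones so
        # that every box contributes equally to the loss.
        if groundtruth_weights is None:
          groundtruth_weights = tf.ones([num_boxes], dtype=tf.float32)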

--
207336575  by Zhichao Lu:

    Move the new argument 'target_assigner_instance' to the end of the list of arguments to the ssd_meta_arch constructor for backwards compatibility.

--
207327862  by Zhichao Lu:

    Enable support for float output in quantized custom op for postprocessing in SSD Mobilenet model.
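
    A self-contained sketch of attaching such attributes to a NodeDef; the
    attribute names are taken from the exporter diff below:

        from tensorflow.core.framework import attr_value_pb2
        from tensorflow.core.framework import node_def_pb2
        from tensorflow.core.framework import types_pb2

        node = node_def_pb2.NodeDef(name='TFLite_Detection_PostProcess',
                                    op='TFLite_Detection_PostProcess')
        # Mark the op as quantized while still allowing float outputs.
        node.attr['_output_quantized'].CopyFrom(attr_value_pb2.AttrValue(b=True))
        node.attr['_support_output_type_float_in_quantized_op'].CopyFrom(
            attr_value_pb2.AttrValue(b=True))
        node.attr['_output_types'].list.type.extend([types_pb2.DT_FLOAT] * 4)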

--
207323154  by Zhichao Lu:

    Bug fix: change dict.iteritems() to dict.items() for Python 3 compatibility.
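
    For example, dict.iteritems() no longer exists in Python 3, while
    dict.items() works in both Python 2 and 3:

        d = {'a': 1, 'b': 2}
        for key, value in d.items():  # not d.iteritems()
          print(key, value)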

--
207301109  by Zhichao Lu:

    Integrating the expected_classification_loss_under_sampling op as an option in the ssd_meta_arch.
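
    The wiring, as it appears in the model builder diff below (assuming
    `ssd_config` is a parsed SSD config proto):

        import functools
        from object_detection.utils import ops

        expected_classification_loss_under_sampling = functools.partial(
            ops.expected_classification_loss_under_sampling,
            minimum_negative_sampling=ssd_config.minimum_negative_sampling,
            desired_negative_sampling_ratio=(
                ssd_config.desired_negative_sampling_ratio))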

--
207286221  by Zhichao Lu:

    Adding an option to weight regression loss with foreground scores from the ground truth labels.
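
    The core of the change, excerpted from the target assigner diff below:
    when the option is enabled, the matched groundtruth scores scale the
    regression weights.

        if self._weight_regression_loss_by_score:
          reg_weights = self._create_regression_weights(
              match, groundtruth_weights * scores)
        else:
          reg_weights = self._create_regression_weights(match,
                                                        groundtruth_weights)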

--
207231739  by Zhichao Lu:

    Explicitly mentioning the argument names when calling the batch target assigner.

--
207206356  by Zhichao Lu:

    Add include_trainable_variables field to train config to better handle trainable variables.

--
207135930  by Zhichao Lu:

    Internal change.

--
206862541  by Zhichao Lu:

    Do not unpad the outputs from batch_non_max_suppression before sampling.

    Since BalancedPositiveNegativeSampler takes an indicator for valid positions to sample from we can pass the output from NMS directly into Sampler.
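
    A hedged sketch (method and variable names assumed): the padded NMS output
    feeds the sampler directly, with the validity mask serving as the
    indicator of positions the sampler may draw from.

        import tensorflow as tf

        # One boolean per padded NMS slot; True only for real detections.
        valid_mask = tf.sequence_mask(num_valid_detections, max_detections)
        sampled = sampler.subsample(valid_mask, batch_size, labels)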

--

PiperOrigin-RevId: 207771702

* Remove unused doc.
Parent fb6bc29b
......@@ -58,7 +58,7 @@ class MultiscaleGridAnchorGenerator(anchor_generator.AnchorGenerator):
self._normalize_coordinates = normalize_coordinates
scales = [2**(float(scale) / scales_per_octave)
for scale in xrange(scales_per_octave)]
for scale in range(scales_per_octave)]
aspects = list(aspect_ratios)
for level in range(min_level, max_level + 1):
......
......@@ -18,12 +18,280 @@
from object_detection.predictors import convolutional_box_predictor
from object_detection.predictors import mask_rcnn_box_predictor
from object_detection.predictors import rfcn_box_predictor
from object_detection.predictors.mask_rcnn_heads import box_head
from object_detection.predictors.mask_rcnn_heads import class_head
from object_detection.predictors.mask_rcnn_heads import mask_head
from object_detection.predictors.heads import box_head
from object_detection.predictors.heads import class_head
from object_detection.predictors.heads import mask_head
from object_detection.protos import box_predictor_pb2
def build_convolutional_box_predictor(
is_training,
num_classes,
conv_hyperparams_fn,
min_depth,
max_depth,
num_layers_before_predictor,
use_dropout,
dropout_keep_prob,
kernel_size,
box_code_size,
apply_sigmoid_to_scores=False,
class_prediction_bias_init=0.0,
use_depthwise=False,
predict_instance_masks=False,
mask_height=7,
mask_width=7,
masks_are_class_agnostic=False):
"""Builds the ConvolutionalBoxPredictor from the arguments.
Args:
is_training: Indicates whether the BoxPredictor is in training mode.
num_classes: Number of classes.
conv_hyperparams_fn: A function to generate tf-slim arg_scope with
hyperparameters for convolution ops.
min_depth: Minimum feature depth prior to predicting box encodings
and class predictions.
max_depth: Maximum feature depth prior to predicting box encodings
and class predictions. If max_depth is set to 0, no additional
feature map will be inserted before location and class predictions.
num_layers_before_predictor: Number of the additional conv layers before
the predictor.
use_dropout: Option to use dropout or not. Note that a single dropout
op is applied here prior to both box and class predictions, which stands
in contrast to the ConvolutionalBoxPredictor below.
dropout_keep_prob: Keep probability for dropout.
This is only used if use_dropout is True.
kernel_size: Size of final convolution kernel. If the
spatial resolution of the feature map is smaller than the kernel size,
then the kernel size is automatically set to be
min(feature_width, feature_height).
box_code_size: Size of encoding for each box.
apply_sigmoid_to_scores: if True, apply the sigmoid on the output
class_predictions.
class_prediction_bias_init: constant value to initialize bias of the last
conv2d layer before class prediction.
use_depthwise: Whether to use depthwise convolutions for prediction
steps. Default is False.
predict_instance_masks: If True, will add a third stage mask prediction
to the returned class.
mask_height: Desired output mask height. The default value is 7.
mask_width: Desired output mask width. The default value is 7.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
Returns:
A ConvolutionalBoxPredictor class.
"""
box_prediction_head = box_head.ConvolutionalBoxHead(
is_training=is_training,
box_code_size=box_code_size,
kernel_size=kernel_size,
use_depthwise=use_depthwise)
class_prediction_head = class_head.ConvolutionalClassHead(
is_training=is_training,
num_classes=num_classes,
use_dropout=use_dropout,
dropout_keep_prob=dropout_keep_prob,
kernel_size=kernel_size,
apply_sigmoid_to_scores=apply_sigmoid_to_scores,
class_prediction_bias_init=class_prediction_bias_init,
use_depthwise=use_depthwise)
other_heads = {}
if predict_instance_masks:
other_heads[convolutional_box_predictor.MASK_PREDICTIONS] = (
mask_head.ConvolutionalMaskHead(
is_training=is_training,
num_classes=num_classes,
use_dropout=use_dropout,
dropout_keep_prob=dropout_keep_prob,
kernel_size=kernel_size,
use_depthwise=use_depthwise,
mask_height=mask_height,
mask_width=mask_width,
masks_are_class_agnostic=masks_are_class_agnostic))
return convolutional_box_predictor.ConvolutionalBoxPredictor(
is_training=is_training,
num_classes=num_classes,
box_prediction_head=box_prediction_head,
class_prediction_head=class_prediction_head,
other_heads=other_heads,
conv_hyperparams_fn=conv_hyperparams_fn,
num_layers_before_predictor=num_layers_before_predictor,
min_depth=min_depth,
max_depth=max_depth)
def build_weight_shared_convolutional_box_predictor(
is_training,
num_classes,
conv_hyperparams_fn,
depth,
num_layers_before_predictor,
box_code_size,
kernel_size=3,
class_prediction_bias_init=0.0,
use_dropout=False,
dropout_keep_prob=0.8,
share_prediction_tower=False,
apply_batch_norm=True,
predict_instance_masks=False,
mask_height=7,
mask_width=7,
masks_are_class_agnostic=False):
"""Builds and returns a WeightSharedConvolutionalBoxPredictor class.
Args:
is_training: Indicates whether the BoxPredictor is in training mode.
num_classes: number of classes. Note that num_classes *does not*
include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}).
conv_hyperparams_fn: A function to generate tf-slim arg_scope with
hyperparameters for convolution ops.
depth: depth of conv layers.
num_layers_before_predictor: Number of the additional conv layers before
the predictor.
box_code_size: Size of encoding for each box.
kernel_size: Size of final convolution kernel.
class_prediction_bias_init: constant value to initialize bias of the last
conv2d layer before class prediction.
use_dropout: Whether to apply dropout to class prediction head.
dropout_keep_prob: Probability of keeping activations.
share_prediction_tower: Whether to share the multi-layer tower between box
prediction and class prediction heads.
apply_batch_norm: Whether to apply batch normalization to conv layers in
this predictor.
predict_instance_masks: If True, will add a third stage mask prediction
to the returned class.
mask_height: Desired output mask height. The default value is 7.
mask_width: Desired output mask width. The default value is 7.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
Returns:
A WeightSharedConvolutionalBoxPredictor class.
"""
box_prediction_head = box_head.WeightSharedConvolutionalBoxHead(
box_code_size=box_code_size,
kernel_size=kernel_size,
class_prediction_bias_init=class_prediction_bias_init)
class_prediction_head = (
class_head.WeightSharedConvolutionalClassHead(
num_classes=num_classes,
kernel_size=kernel_size,
class_prediction_bias_init=class_prediction_bias_init,
use_dropout=use_dropout,
dropout_keep_prob=dropout_keep_prob))
other_heads = {}
if predict_instance_masks:
other_heads[convolutional_box_predictor.MASK_PREDICTIONS] = (
mask_head.WeightSharedConvolutionalMaskHead(
num_classes=num_classes,
kernel_size=kernel_size,
use_dropout=use_dropout,
dropout_keep_prob=dropout_keep_prob,
mask_height=mask_height,
mask_width=mask_width,
masks_are_class_agnostic=masks_are_class_agnostic))
return convolutional_box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=is_training,
num_classes=num_classes,
box_prediction_head=box_prediction_head,
class_prediction_head=class_prediction_head,
other_heads=other_heads,
conv_hyperparams_fn=conv_hyperparams_fn,
depth=depth,
num_layers_before_predictor=num_layers_before_predictor,
kernel_size=kernel_size,
apply_batch_norm=apply_batch_norm,
share_prediction_tower=share_prediction_tower)
def build_mask_rcnn_box_predictor(is_training,
num_classes,
fc_hyperparams_fn,
use_dropout,
dropout_keep_prob,
box_code_size,
share_box_across_classes=False,
predict_instance_masks=False,
conv_hyperparams_fn=None,
mask_height=14,
mask_width=14,
mask_prediction_num_conv_layers=2,
mask_prediction_conv_depth=256,
masks_are_class_agnostic=False):
"""Builds and returns a MaskRCNNBoxPredictor class.
Args:
is_training: Indicates whether the BoxPredictor is in training mode.
num_classes: number of classes. Note that num_classes *does not*
include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}).
fc_hyperparams_fn: A function to generate tf-slim arg_scope with
hyperparameters for fully connected ops.
use_dropout: Option to use dropout or not. Note that a single dropout
op is applied here prior to both box and class predictions, which stands
in contrast to the ConvolutionalBoxPredictor below.
dropout_keep_prob: Keep probability for dropout.
This is only used if use_dropout is True.
box_code_size: Size of encoding for each box.
share_box_across_classes: Whether to share boxes across classes rather
than use a different box for each class.
predict_instance_masks: If True, will add a third stage mask prediction
to the returned class.
conv_hyperparams_fn: A function to generate tf-slim arg_scope with
hyperparameters for convolution ops.
mask_height: Desired output mask height. The default value is 14.
mask_width: Desired output mask width. The default value is 14.
mask_prediction_num_conv_layers: Number of convolution layers applied to
the image_features in mask prediction branch.
mask_prediction_conv_depth: The depth for the first conv2d_transpose op
applied to the image_features in the mask prediction branch. If set
to 0, the depth of the convolution layers will be automatically chosen
based on the number of object classes and the number of channels in the
image features.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
Returns:
A MaskRCNNBoxPredictor class.
"""
box_prediction_head = box_head.MaskRCNNBoxHead(
is_training=is_training,
num_classes=num_classes,
fc_hyperparams_fn=fc_hyperparams_fn,
use_dropout=use_dropout,
dropout_keep_prob=dropout_keep_prob,
box_code_size=box_code_size,
share_box_across_classes=share_box_across_classes)
class_prediction_head = class_head.MaskRCNNClassHead(
is_training=is_training,
num_classes=num_classes,
fc_hyperparams_fn=fc_hyperparams_fn,
use_dropout=use_dropout,
dropout_keep_prob=dropout_keep_prob)
third_stage_heads = {}
if predict_instance_masks:
third_stage_heads[
mask_rcnn_box_predictor.
MASK_PREDICTIONS] = mask_head.MaskRCNNMaskHead(
num_classes=num_classes,
conv_hyperparams_fn=conv_hyperparams_fn,
mask_height=mask_height,
mask_width=mask_width,
mask_prediction_num_conv_layers=mask_prediction_num_conv_layers,
mask_prediction_conv_depth=mask_prediction_conv_depth,
masks_are_class_agnostic=masks_are_class_agnostic)
return mask_rcnn_box_predictor.MaskRCNNBoxPredictor(
is_training=is_training,
num_classes=num_classes,
box_prediction_head=box_prediction_head,
class_prediction_head=class_prediction_head,
third_stage_heads=third_stage_heads)
def build(argscope_fn, box_predictor_config, is_training, num_classes):
"""Builds box predictor based on the configuration.
......@@ -56,25 +324,22 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
config_box_predictor = box_predictor_config.convolutional_box_predictor
conv_hyperparams_fn = argscope_fn(config_box_predictor.conv_hyperparams,
is_training)
box_predictor_object = (
convolutional_box_predictor.ConvolutionalBoxPredictor(
is_training=is_training,
num_classes=num_classes,
conv_hyperparams_fn=conv_hyperparams_fn,
min_depth=config_box_predictor.min_depth,
max_depth=config_box_predictor.max_depth,
num_layers_before_predictor=(
config_box_predictor.num_layers_before_predictor),
use_dropout=config_box_predictor.use_dropout,
dropout_keep_prob=config_box_predictor.dropout_keep_probability,
kernel_size=config_box_predictor.kernel_size,
box_code_size=config_box_predictor.box_code_size,
apply_sigmoid_to_scores=config_box_predictor.
apply_sigmoid_to_scores,
class_prediction_bias_init=(
config_box_predictor.class_prediction_bias_init),
use_depthwise=config_box_predictor.use_depthwise))
return box_predictor_object
return build_convolutional_box_predictor(
is_training=is_training,
num_classes=num_classes,
conv_hyperparams_fn=conv_hyperparams_fn,
use_dropout=config_box_predictor.use_dropout,
dropout_keep_prob=config_box_predictor.dropout_keep_probability,
box_code_size=config_box_predictor.box_code_size,
kernel_size=config_box_predictor.kernel_size,
num_layers_before_predictor=(
config_box_predictor.num_layers_before_predictor),
min_depth=config_box_predictor.min_depth,
max_depth=config_box_predictor.max_depth,
apply_sigmoid_to_scores=config_box_predictor.apply_sigmoid_to_scores,
class_prediction_bias_init=(
config_box_predictor.class_prediction_bias_init),
use_depthwise=config_box_predictor.use_depthwise)
if box_predictor_oneof == 'weight_shared_convolutional_box_predictor':
config_box_predictor = (
......@@ -83,23 +348,21 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
is_training)
apply_batch_norm = config_box_predictor.conv_hyperparams.HasField(
'batch_norm')
box_predictor_object = (
convolutional_box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=is_training,
num_classes=num_classes,
conv_hyperparams_fn=conv_hyperparams_fn,
depth=config_box_predictor.depth,
num_layers_before_predictor=(
config_box_predictor.num_layers_before_predictor),
kernel_size=config_box_predictor.kernel_size,
box_code_size=config_box_predictor.box_code_size,
class_prediction_bias_init=config_box_predictor.
class_prediction_bias_init,
use_dropout=config_box_predictor.use_dropout,
dropout_keep_prob=config_box_predictor.dropout_keep_probability,
share_prediction_tower=config_box_predictor.share_prediction_tower,
apply_batch_norm=apply_batch_norm))
return box_predictor_object
return build_weight_shared_convolutional_box_predictor(
is_training=is_training,
num_classes=num_classes,
conv_hyperparams_fn=conv_hyperparams_fn,
depth=config_box_predictor.depth,
num_layers_before_predictor=(
config_box_predictor.num_layers_before_predictor),
box_code_size=config_box_predictor.box_code_size,
kernel_size=config_box_predictor.kernel_size,
class_prediction_bias_init=(
config_box_predictor.class_prediction_bias_init),
use_dropout=config_box_predictor.use_dropout,
dropout_keep_prob=config_box_predictor.dropout_keep_probability,
share_prediction_tower=config_box_predictor.share_prediction_tower,
apply_batch_norm=apply_batch_norm)
if box_predictor_oneof == 'mask_rcnn_box_predictor':
config_box_predictor = box_predictor_config.mask_rcnn_box_predictor
......@@ -109,7 +372,7 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
if config_box_predictor.HasField('conv_hyperparams'):
conv_hyperparams_fn = argscope_fn(
config_box_predictor.conv_hyperparams, is_training)
box_prediction_head = box_head.BoxHead(
return build_mask_rcnn_box_predictor(
is_training=is_training,
num_classes=num_classes,
fc_hyperparams_fn=fc_hyperparams_fn,
......@@ -117,34 +380,17 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
dropout_keep_prob=config_box_predictor.dropout_keep_probability,
box_code_size=config_box_predictor.box_code_size,
share_box_across_classes=(
config_box_predictor.share_box_across_classes))
class_prediction_head = class_head.ClassHead(
is_training=is_training,
num_classes=num_classes,
fc_hyperparams_fn=fc_hyperparams_fn,
use_dropout=config_box_predictor.use_dropout,
dropout_keep_prob=config_box_predictor.dropout_keep_probability)
third_stage_heads = {}
if config_box_predictor.predict_instance_masks:
third_stage_heads[
mask_rcnn_box_predictor.MASK_PREDICTIONS] = mask_head.MaskHead(
num_classes=num_classes,
conv_hyperparams_fn=conv_hyperparams_fn,
mask_height=config_box_predictor.mask_height,
mask_width=config_box_predictor.mask_width,
mask_prediction_num_conv_layers=(
config_box_predictor.mask_prediction_num_conv_layers),
mask_prediction_conv_depth=(
config_box_predictor.mask_prediction_conv_depth),
masks_are_class_agnostic=(
config_box_predictor.masks_are_class_agnostic))
box_predictor_object = mask_rcnn_box_predictor.MaskRCNNBoxPredictor(
is_training=is_training,
num_classes=num_classes,
box_prediction_head=box_prediction_head,
class_prediction_head=class_prediction_head,
third_stage_heads=third_stage_heads)
return box_predictor_object
config_box_predictor.share_box_across_classes),
predict_instance_masks=config_box_predictor.predict_instance_masks,
conv_hyperparams_fn=conv_hyperparams_fn,
mask_height=config_box_predictor.mask_height,
mask_width=config_box_predictor.mask_width,
mask_prediction_num_conv_layers=(
config_box_predictor.mask_prediction_num_conv_layers),
mask_prediction_conv_depth=(
config_box_predictor.mask_prediction_conv_depth),
masks_are_class_agnostic=(
config_box_predictor.masks_are_class_agnostic))
if box_predictor_oneof == 'rfcn_box_predictor':
config_box_predictor = box_predictor_config.rfcn_box_predictor
......
......@@ -111,16 +111,17 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto,
is_training=False,
num_classes=10)
class_head = box_predictor._class_prediction_head
self.assertEqual(box_predictor._min_depth, 2)
self.assertEqual(box_predictor._max_depth, 16)
self.assertEqual(box_predictor._num_layers_before_predictor, 2)
self.assertFalse(box_predictor._use_dropout)
self.assertAlmostEqual(box_predictor._dropout_keep_prob, 0.4)
self.assertTrue(box_predictor._apply_sigmoid_to_scores)
self.assertAlmostEqual(box_predictor._class_prediction_bias_init, 4.0)
self.assertFalse(class_head._use_dropout)
self.assertAlmostEqual(class_head._dropout_keep_prob, 0.4)
self.assertTrue(class_head._apply_sigmoid_to_scores)
self.assertAlmostEqual(class_head._class_prediction_bias_init, 4.0)
self.assertEqual(box_predictor.num_classes, 10)
self.assertFalse(box_predictor._is_training)
self.assertTrue(box_predictor._use_depthwise)
self.assertTrue(class_head._use_depthwise)
def test_construct_default_conv_box_predictor(self):
box_predictor_text_proto = """
......@@ -143,15 +144,16 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto,
is_training=True,
num_classes=90)
class_head = box_predictor._class_prediction_head
self.assertEqual(box_predictor._min_depth, 0)
self.assertEqual(box_predictor._max_depth, 0)
self.assertEqual(box_predictor._num_layers_before_predictor, 0)
self.assertTrue(box_predictor._use_dropout)
self.assertAlmostEqual(box_predictor._dropout_keep_prob, 0.8)
self.assertFalse(box_predictor._apply_sigmoid_to_scores)
self.assertTrue(class_head._use_dropout)
self.assertAlmostEqual(class_head._dropout_keep_prob, 0.8)
self.assertFalse(class_head._apply_sigmoid_to_scores)
self.assertEqual(box_predictor.num_classes, 90)
self.assertTrue(box_predictor._is_training)
self.assertFalse(box_predictor._use_depthwise)
self.assertFalse(class_head._use_depthwise)
class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
......@@ -235,12 +237,13 @@ class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto,
is_training=False,
num_classes=10)
class_head = box_predictor._class_prediction_head
self.assertEqual(box_predictor._depth, 2)
self.assertEqual(box_predictor._num_layers_before_predictor, 2)
self.assertAlmostEqual(box_predictor._class_prediction_bias_init, 4.0)
self.assertEqual(box_predictor._apply_batch_norm, False)
self.assertAlmostEqual(class_head._class_prediction_bias_init, 4.0)
self.assertEqual(box_predictor.num_classes, 10)
self.assertFalse(box_predictor._is_training)
self.assertEqual(box_predictor._apply_batch_norm, False)
def test_construct_default_conv_box_predictor(self):
box_predictor_text_proto = """
......
......@@ -16,6 +16,7 @@
"""Builder function to construct tf-slim arg_scope for convolution, fc ops."""
import tensorflow as tf
from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2
from object_detection.utils import context_manager
......@@ -93,6 +94,38 @@ class KerasLayerHyperparams(object):
new_batch_norm_params.update(overrides)
return new_batch_norm_params
def build_batch_norm(self, training=None, **overrides):
"""Returns a Batch Normalization layer with the appropriate hyperparams.
If the hyperparams are configured to not use batch normalization,
this will return a Keras Lambda layer that only applies tf.identity,
without doing any normalization.
Optionally overrides values in the batch_norm hyperparam dict. Overrides
only apply to individual calls of this method, and do not affect
future calls.
Args:
training: if True, the normalization layer will normalize using the batch
statistics. If False, the normalization layer will be frozen and will
act as if it is being used for inference. If None, the layer
will look up the Keras learning phase at `call` time to decide what to
do.
**overrides: batch normalization construction args to override from the
batch_norm hyperparams dictionary.
Returns: Either a FreezableBatchNorm layer (if use_batch_norm() is True),
or a Keras Lambda layer that applies the identity (if use_batch_norm()
is False)
"""
if self.use_batch_norm():
return freezable_batch_norm.FreezableBatchNorm(
training=training,
**self.batch_norm_params(**overrides)
)
else:
return tf.keras.layers.Lambda(tf.identity)
def params(self, **overrides):
"""Returns a dict containing the layer construction hyperparameters to use.
......
......@@ -21,6 +21,7 @@ import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2
slim = tf.contrib.slim
......@@ -282,6 +283,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
batch_norm_layer = keras_config.build_batch_norm()
self.assertTrue(isinstance(batch_norm_layer,
freezable_batch_norm.FreezableBatchNorm))
def test_return_non_default_batch_norm_params_keras_override(
self):
conv_hyperparams_text_proto = """
......@@ -413,6 +418,11 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self.assertFalse(keras_config.use_batch_norm())
self.assertEqual(keras_config.batch_norm_params(), {})
# The batch norm builder should build an identity Lambda layer
identity_layer = keras_config.build_batch_norm()
self.assertTrue(isinstance(identity_layer,
tf.keras.layers.Lambda))
def test_use_none_activation(self):
conv_hyperparams_text_proto = """
regularizer {
......
......@@ -14,6 +14,7 @@
# ==============================================================================
"""A function to build a DetectionModel from configuration."""
import functools
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
......@@ -44,6 +45,8 @@ from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMo
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
from object_detection.predictors import rfcn_box_predictor
from object_detection.protos import model_pb2
from object_detection.utils import ops
# A map of names to SSD feature extractors.
SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
......@@ -220,6 +223,22 @@ def _build_ssd_model(ssd_config, is_training, add_summaries,
random_example_sampler) = losses_builder.build(ssd_config.loss)
normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize
weight_regression_loss_by_score = (ssd_config.weight_regression_loss_by_score)
target_assigner_instance = target_assigner.TargetAssigner(
region_similarity_calculator,
matcher,
box_coder,
negative_class_weight=negative_class_weight,
weight_regression_loss_by_score=weight_regression_loss_by_score)
expected_classification_loss_under_sampling = None
if ssd_config.use_expected_classification_loss_under_sampling:
expected_classification_loss_under_sampling = functools.partial(
ops.expected_classification_loss_under_sampling,
minimum_negative_sampling=ssd_config.minimum_negative_sampling,
desired_negative_sampling_ratio=ssd_config.
desired_negative_sampling_ratio)
return ssd_meta_arch.SSDMetaArch(
is_training,
......@@ -240,12 +259,15 @@ def _build_ssd_model(ssd_config, is_training, add_summaries,
localization_weight,
normalize_loss_by_num_matches,
hard_example_miner,
target_assigner_instance=target_assigner_instance,
add_summaries=add_summaries,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=ssd_config.freeze_batchnorm,
inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
add_background_class=add_background_class,
random_example_sampler=random_example_sampler)
random_example_sampler=random_example_sampler,
expected_classification_loss_under_sampling=
expected_classification_loss_under_sampling)
def _build_faster_rcnn_feature_extractor(
......
......@@ -144,6 +144,9 @@ class ModelBuilderTest(tf.test.TestCase):
}
}
}
use_expected_classification_loss_under_sampling: true
minimum_negative_sampling: 10
desired_negative_sampling_ratio: 2
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
......@@ -151,6 +154,12 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor,
SSDInceptionV2FeatureExtractor)
self.assertIsNotNone(model._expected_classification_loss_under_sampling)
self.assertEqual(
model._expected_classification_loss_under_sampling.keywords, {
'minimum_negative_sampling': 10,
'desired_negative_sampling_ratio': 2
})
def test_create_ssd_inception_v3_model_from_config(self):
model_text_proto = """
......@@ -692,6 +701,7 @@ class ModelBuilderTest(tf.test.TestCase):
}
}
}
weight_regression_loss_by_score: true
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
......@@ -700,6 +710,7 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model._feature_extractor,
SSDMobileNetV2FeatureExtractor)
self.assertTrue(model._normalize_loc_loss_by_codesize)
self.assertTrue(model._target_assigner._weight_regression_loss_by_score)
def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A freezable batch norm layer that uses Keras batch normalization."""
import tensorflow as tf
class FreezableBatchNorm(tf.keras.layers.BatchNormalization):
"""Batch normalization layer (Ioffe and Szegedy, 2014).
This is a `freezable` batch norm layer that supports setting the `training`
parameter in the __init__ method rather than having to set it either via
the Keras learning phase or via the `call` method parameter. This layer will
forward all other parameters to the default Keras `BatchNormalization`
layer.
This class is necessary because Object Detection model training sometimes
requires batch normalization layers to be `frozen` and used as if it were
evaluation time, despite still training (and potentially using dropout layers).
Like the default Keras BatchNormalization layer, this will normalize the
activations of the previous layer at each batch,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.
Arguments:
training: Boolean or None. If True, the batch normalization layer will
normalize the input batch using the batch mean and standard deviation,
and update the total moving mean and standard deviations. If False, the
layer will normalize using the moving average and std. dev, without
updating the learned avg and std. dev.
If None, the layer will follow the keras BatchNormalization layer
strategy of checking the Keras learning phase at `call` time to decide
what to do.
**kwargs: The keyword arguments to forward to the keras BatchNormalization
layer constructor.
Input shape:
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
Output shape:
Same shape as input.
References:
- [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
"""
def __init__(self, training=None, **kwargs):
super(FreezableBatchNorm, self).__init__(**kwargs)
self._training = training
def call(self, inputs, training=None):
if training is None:
training = self._training
return super(FreezableBatchNorm, self).call(inputs, training=training)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.core.freezable_batch_norm."""
import numpy as np
import tensorflow as tf
from object_detection.core import freezable_batch_norm
class FreezableBatchNormTest(tf.test.TestCase):
"""Tests for FreezableBatchNorm operations."""
def _build_model(self, training=None):
model = tf.keras.models.Sequential()
norm = freezable_batch_norm.FreezableBatchNorm(training=training,
input_shape=(10,),
momentum=0.8)
model.add(norm)
return model, norm
def _train_freezable_batch_norm(self, training_mean, training_var):
model, _ = self._build_model()
model.compile(loss='mse', optimizer='sgd')
# centered on training_mean, variance training_var
train_data = np.random.normal(
loc=training_mean,
scale=training_var,
size=(1000, 10))
model.fit(train_data, train_data, epochs=4, verbose=0)
return model.weights
def test_batchnorm_freezing_training_true(self):
with self.test_session():
training_mean = 5.0
training_var = 10.0
testing_mean = -10.0
testing_var = 5.0
# Initially train the batch norm, and save the weights
trained_weights = self._train_freezable_batch_norm(training_mean,
training_var)
# Load the batch norm weights, freezing training to True.
# Apply the batch norm layer to testing data and ensure it is normalized
# according to the batch statistics.
model, norm = self._build_model(training=True)
for trained_weight, blank_weight in zip(trained_weights, model.weights):
weight_copy = blank_weight.assign(tf.keras.backend.eval(trained_weight))
tf.keras.backend.eval(weight_copy)
# centered on testing_mean, variance testing_var
test_data = np.random.normal(
loc=testing_mean,
scale=testing_var,
size=(1000, 10))
out_tensor = norm(tf.convert_to_tensor(test_data, dtype=tf.float32))
out = tf.keras.backend.eval(out_tensor)
out -= tf.keras.backend.eval(norm.beta)
out /= tf.keras.backend.eval(norm.gamma)
np.testing.assert_allclose(out.mean(), 0.0, atol=1.5e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1.5e-1)
def test_batchnorm_freezing_training_false(self):
with self.test_session():
training_mean = 5.0
training_var = 10.0
testing_mean = -10.0
testing_var = 5.0
# Initially train the batch norm, and save the weights
trained_weights = self._train_freezable_batch_norm(training_mean,
training_var)
# Load the batch norm back up, freezing training to False.
# Apply the batch norm layer to testing data and ensure it is normalized
# according to the training data's statistics.
model, norm = self._build_model(training=False)
for trained_weight, blank_weight in zip(trained_weights, model.weights):
weight_copy = blank_weight.assign(tf.keras.backend.eval(trained_weight))
tf.keras.backend.eval(weight_copy)
# centered on testing_mean, variance testing_var
test_data = np.random.normal(
loc=testing_mean,
scale=testing_var,
size=(1000, 10))
out_tensor = norm(tf.convert_to_tensor(test_data, dtype=tf.float32))
out = tf.keras.backend.eval(out_tensor)
out -= tf.keras.backend.eval(norm.beta)
out /= tf.keras.backend.eval(norm.gamma)
out *= training_var
out += (training_mean - testing_mean)
out /= testing_var
np.testing.assert_allclose(out.mean(), 0.0, atol=1.5e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1.5e-1)
if __name__ == '__main__':
tf.test.main()
......@@ -47,6 +47,9 @@ def multiclass_non_max_suppression(boxes,
Please note that this operation is performed on *all* classes, therefore any
background classes should be removed prior to calling this function.
Selected boxes are guaranteed to be sorted in decreasing order by score (but
the sort is not guaranteed to be stable).
Args:
boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be either
number of classes or 1 depending on whether a separate box is predicted
......@@ -106,15 +109,9 @@ def multiclass_non_max_suppression(boxes,
'must be specified.')
with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
num_boxes = tf.shape(boxes)[0]
num_scores = tf.shape(scores)[0]
num_classes = scores.get_shape()[1]
length_assert = tf.Assert(
tf.equal(num_boxes, num_scores),
['Incorrect scores field length: actual vs expected.',
num_scores, num_boxes])
selected_boxes_list = []
per_class_boxes_list = tf.unstack(boxes, axis=1)
if masks is not None:
......@@ -126,9 +123,9 @@ def multiclass_non_max_suppression(boxes,
for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
per_class_boxes = per_class_boxes_list[boxes_idx]
boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
with tf.control_dependencies([length_assert]):
class_scores = tf.reshape(
tf.slice(scores, [0, class_idx], tf.stack([num_scores, 1])), [-1])
class_scores = tf.reshape(
tf.slice(scores, [0, class_idx], tf.stack([num_scores, 1])), [-1])
boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
class_scores)
if masks is not None:
......@@ -142,22 +139,17 @@ def multiclass_non_max_suppression(boxes,
if additional_fields is not None:
for key, tensor in additional_fields.items():
boxlist_and_class_scores.add_field(key, tensor)
boxlist_filtered = box_list_ops.filter_greater_than(
boxlist_and_class_scores, score_thresh)
if clip_window is not None:
boxlist_filtered = box_list_ops.clip_to_window(
boxlist_filtered, clip_window)
if change_coordinate_frame:
boxlist_filtered = box_list_ops.change_coordinate_frame(
boxlist_filtered, clip_window)
max_selection_size = tf.minimum(max_size_per_class,
boxlist_filtered.num_boxes())
boxlist_and_class_scores.num_boxes())
selected_indices = tf.image.non_max_suppression(
boxlist_filtered.get(),
boxlist_filtered.get_field(fields.BoxListFields.scores),
boxlist_and_class_scores.get(),
boxlist_and_class_scores.get_field(fields.BoxListFields.scores),
max_selection_size,
iou_threshold=iou_thresh)
nms_result = box_list_ops.gather(boxlist_filtered, selected_indices)
iou_threshold=iou_thresh,
score_threshold=score_thresh)
nms_result = box_list_ops.gather(boxlist_and_class_scores,
selected_indices)
nms_result.add_field(
fields.BoxListFields.classes, (tf.zeros_like(
nms_result.get_field(fields.BoxListFields.scores)) + class_idx))
......@@ -165,6 +157,11 @@ def multiclass_non_max_suppression(boxes,
selected_boxes = box_list_ops.concatenate(selected_boxes_list)
sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
fields.BoxListFields.scores)
if clip_window is not None:
sorted_boxes = box_list_ops.clip_to_window(sorted_boxes, clip_window)
if change_coordinate_frame:
sorted_boxes = box_list_ops.change_coordinate_frame(
sorted_boxes, clip_window)
if max_total_size:
max_total_size = tf.minimum(max_total_size,
sorted_boxes.num_boxes())
......
......@@ -22,24 +22,6 @@ from object_detection.core import standard_fields as fields
class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
def test_with_invalid_scores_size(self):
boxes = tf.constant([[[0, 0, 1, 1]],
[[0, 0.1, 1, 1.1]],
[[0, -0.1, 1, 0.9]],
[[0, 10, 1, 11]],
[[0, 10.1, 1, 11.1]],
[[0, 100, 1, 101]]], tf.float32)
scores = tf.constant([[.9], [.75], [.6], [.95], [.5]])
iou_thresh = .5
score_thresh = 0.6
max_output_size = 3
nms = post_processing.multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh, max_output_size)
with self.test_session() as sess:
with self.assertRaisesWithPredicateMatch(
tf.errors.InvalidArgumentError, 'Incorrect scores field length'):
sess.run(nms.get())
def test_multiclass_nms_select_with_shared_boxes(self):
boxes = tf.constant([[[0, 0, 1, 1]],
[[0, 0.1, 1, 1.1]],
......
......@@ -48,8 +48,12 @@ from object_detection.utils import shape_utils
class TargetAssigner(object):
"""Target assigner to compute classification and regression targets."""
def __init__(self, similarity_calc, matcher, box_coder,
negative_class_weight=1.0):
def __init__(self,
similarity_calc,
matcher,
box_coder,
negative_class_weight=1.0,
weight_regression_loss_by_score=False):
"""Construct Object Detection Target Assigner.
Args:
......@@ -60,6 +64,8 @@ class TargetAssigner(object):
groundtruth boxes with respect to anchors.
negative_class_weight: classification weight to be associated to negative
anchors (default: 1.0). The weight must be in [0., 1.].
weight_regression_loss_by_score: Whether to weight the regression loss by
ground truth box score.
Raises:
ValueError: if similarity_calc is not a RegionSimilarityCalculator or
......@@ -75,14 +81,20 @@ class TargetAssigner(object):
self._matcher = matcher
self._box_coder = box_coder
self._negative_class_weight = negative_class_weight
self._weight_regression_loss_by_score = weight_regression_loss_by_score
@property
def box_coder(self):
return self._box_coder
# TODO(rathodv): move labels, scores, and weights to groundtruth_boxes fields.
def assign(self, anchors, groundtruth_boxes, groundtruth_labels=None,
unmatched_class_label=None, groundtruth_weights=None, **params):
def assign(self,
anchors,
groundtruth_boxes,
groundtruth_labels=None,
unmatched_class_label=None,
groundtruth_weights=None,
**params):
"""Assign classification and regression targets to each anchor.
For a given set of anchors and groundtruth detections, match anchors
......@@ -172,7 +184,13 @@ class TargetAssigner(object):
cls_targets = self._create_classification_targets(groundtruth_labels,
unmatched_class_label,
match)
reg_weights = self._create_regression_weights(match, groundtruth_weights)
if self._weight_regression_loss_by_score:
reg_weights = self._create_regression_weights(
match, groundtruth_weights * scores)
else:
reg_weights = self._create_regression_weights(match,
groundtruth_weights)
cls_weights = self._create_classification_weights(match,
groundtruth_weights)
......@@ -458,9 +476,9 @@ def batch_assign_targets(target_assigner,
gt_weights_batch = [None] * len(gt_class_targets_batch)
for anchors, gt_boxes, gt_class_targets, gt_weights in zip(
anchors_batch, gt_box_batch, gt_class_targets_batch, gt_weights_batch):
(cls_targets, cls_weights, reg_targets, reg_weights,
match) = target_assigner.assign(anchors, gt_boxes, gt_class_targets,
unmatched_class_label, gt_weights)
(cls_targets, cls_weights,
reg_targets, reg_weights, match) = target_assigner.assign(
anchors, gt_boxes, gt_class_targets, unmatched_class_label, gt_weights)
cls_targets_list.append(cls_targets)
cls_weights_list.append(cls_weights)
reg_targets_list.append(reg_targets)
......
......@@ -318,6 +318,50 @@ class TargetAssignerTest(test_case.TestCase):
self.assertAllClose(cls_weights_out, exp_cls_weights)
self.assertAllClose(reg_weights_out, exp_reg_weights)
def test_assign_multiclass_with_weight_regression_loss_by_score(self):
def graph_fn(anchor_means, groundtruth_box_corners, groundtruth_labels):
similarity_calc = region_similarity_calculator.IouSimilarity()
matcher = argmax_matcher.ArgMaxMatcher(
matched_threshold=0.5, unmatched_threshold=0.5)
box_coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
unmatched_class_label = tf.constant([1, 0, 0, 0, 0, 0, 0], tf.float32)
target_assigner = targetassigner.TargetAssigner(
similarity_calc,
matcher,
box_coder,
weight_regression_loss_by_score=True)
anchors_boxlist = box_list.BoxList(anchor_means)
groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners)
result = target_assigner.assign(
anchors_boxlist,
groundtruth_boxlist,
groundtruth_labels,
unmatched_class_label=unmatched_class_label)
(_, cls_weights, _, reg_weights, _) = result
return (cls_weights, reg_weights)
anchor_means = np.array(
[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8], [0, 0.5, .5, 1.0],
[.75, 0, 1.0, .25]],
dtype=np.float32)
groundtruth_box_corners = np.array(
[[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.9, 0.9], [.75, 0, .95, .27]],
dtype=np.float32)
groundtruth_labels = np.array(
[[.9, .1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0],
[.5, 0, 0, .5, 0, 0, 0]],
dtype=np.float32)
exp_cls_weights = [1, 1, 1, 1] # background class gets weight of 1.
exp_reg_weights = [.1, 1, 0., .5] # background class gets weight of 0.
(cls_weights_out, reg_weights_out) = self.execute(
graph_fn, [anchor_means, groundtruth_box_corners, groundtruth_labels])
self.assertAllClose(cls_weights_out, exp_cls_weights)
self.assertAllClose(reg_weights_out, exp_reg_weights)
def test_assign_multidimensional_class_targets(self):
def graph_fn(anchor_means, groundtruth_box_corners, groundtruth_labels):
......
......@@ -32,6 +32,18 @@ from object_detection.utils import visualization_utils as vis_utils
slim = tf.contrib.slim
# A dictionary of metric names to classes that implement the metric. The classes
# in the dictionary must implement
# utils.object_detection_evaluation.DetectionEvaluator interface.
EVAL_METRICS_CLASS_DICT = {
'coco_detection_metrics':
coco_evaluation.CocoDetectionEvaluator,
'coco_mask_metrics':
coco_evaluation.CocoMaskEvaluator,
}
EVAL_DEFAULT_METRIC = 'coco_detection_metrics'
def write_metrics(metrics, global_step, summary_dir):
"""Write metrics to a summary directory.
......@@ -582,70 +594,90 @@ def result_dict_for_single_example(image,
return output_dict
def get_eval_metric_ops_for_evaluators(evaluation_metrics,
def get_evaluators(eval_config, categories, evaluator_options=None):
"""Returns the evaluator class according to eval_config, valid for categories.
Args:
eval_config: An `eval_pb2.EvalConfig`.
categories: A list of dicts, each of which has the following keys -
'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name e.g., 'cat', 'dog'.
evaluator_options: A dictionary of metric names (see
EVAL_METRICS_CLASS_DICT) to `DetectionEvaluator` initialization
keyword arguments. For example:
evaluator_options = {
'coco_detection_metrics': {'include_metrics_per_category': True}
}
Returns:
A list of instances of DetectionEvaluator.
Raises:
ValueError: if metric is not in the metric class dictionary.
"""
evaluator_options = evaluator_options or {}
eval_metric_fn_keys = eval_config.metrics_set
if not eval_metric_fn_keys:
eval_metric_fn_keys = [EVAL_DEFAULT_METRIC]
evaluators_list = []
for eval_metric_fn_key in eval_metric_fn_keys:
if eval_metric_fn_key not in EVAL_METRICS_CLASS_DICT:
raise ValueError('Metric not found: {}'.format(eval_metric_fn_key))
kwargs_dict = (evaluator_options[eval_metric_fn_key] if eval_metric_fn_key
in evaluator_options else {})
evaluators_list.append(EVAL_METRICS_CLASS_DICT[eval_metric_fn_key](
categories,
**kwargs_dict))
return evaluators_list
def get_eval_metric_ops_for_evaluators(eval_config,
categories,
eval_dict,
include_metrics_per_category=False):
"""Returns a dictionary of eval metric ops to use with `tf.EstimatorSpec`.
eval_dict):
"""Returns eval metrics ops to use with `tf.estimator.EstimatorSpec`.
Args:
evaluation_metrics: List of evaluation metric names. Current options are
'coco_detection_metrics' and 'coco_mask_metrics'.
eval_config: An `eval_pb2.EvalConfig`.
categories: A list of dicts, each of which has the following keys -
'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name e.g., 'cat', 'dog'.
eval_dict: An evaluation dictionary, returned from
result_dict_for_single_example().
include_metrics_per_category: If True, additionally include per-category
metrics.
Returns:
A dictionary of metric names to tuple of value_op and update_op that can be
used as eval metric ops in tf.EstimatorSpec.
Raises:
ValueError: If any of the metrics in `evaluation_metric` is not
'coco_detection_metrics' or 'coco_mask_metrics'.
"""
evaluation_metrics = list(set(evaluation_metrics))
input_data_fields = fields.InputDataFields
detection_fields = fields.DetectionResultFields
eval_metric_ops = {}
for metric in evaluation_metrics:
if metric == 'coco_detection_metrics':
coco_evaluator = coco_evaluation.CocoDetectionEvaluator(
categories, include_metrics_per_category=include_metrics_per_category)
eval_metric_ops.update(
coco_evaluator.get_estimator_eval_metric_ops(
image_id=eval_dict[input_data_fields.key],
groundtruth_boxes=eval_dict[input_data_fields.groundtruth_boxes],
groundtruth_classes=eval_dict[
input_data_fields.groundtruth_classes],
detection_boxes=eval_dict[detection_fields.detection_boxes],
detection_scores=eval_dict[detection_fields.detection_scores],
detection_classes=eval_dict[detection_fields.detection_classes],
groundtruth_is_crowd=eval_dict.get(
input_data_fields.groundtruth_is_crowd)))
elif metric == 'coco_mask_metrics':
coco_mask_evaluator = coco_evaluation.CocoMaskEvaluator(
categories, include_metrics_per_category=include_metrics_per_category)
eval_metric_ops.update(
coco_mask_evaluator.get_estimator_eval_metric_ops(
image_id=eval_dict[input_data_fields.key],
groundtruth_boxes=eval_dict[input_data_fields.groundtruth_boxes],
groundtruth_classes=eval_dict[
input_data_fields.groundtruth_classes],
groundtruth_instance_masks=eval_dict[
input_data_fields.groundtruth_instance_masks],
detection_scores=eval_dict[detection_fields.detection_scores],
detection_classes=eval_dict[detection_fields.detection_classes],
detection_masks=eval_dict[detection_fields.detection_masks],
groundtruth_is_crowd=eval_dict.get(
input_data_fields.groundtruth_is_crowd),))
else:
raise ValueError('The only evaluation metrics supported are '
'"coco_detection_metrics" and "coco_mask_metrics". '
'Found {} in the evaluation metrics'.format(metric))
evaluator_options = evaluator_options_from_eval_config(eval_config)
evaluators_list = get_evaluators(eval_config, categories, evaluator_options)
for evaluator in evaluators_list:
eval_metric_ops.update(evaluator.get_estimator_eval_metric_ops(
eval_dict))
return eval_metric_ops
def evaluator_options_from_eval_config(eval_config):
"""Produces a dictionary of evaluation options for each eval metric.
Args:
eval_config: An `eval_pb2.EvalConfig`.
Returns:
evaluator_options: A dictionary of metric names (see
EVAL_METRICS_CLASS_DICT) to `DetectionEvaluator` initialization
keyword arguments. For example:
evaluator_options = {
'coco_detection_metrics': {'include_metrics_per_category': True}
}
"""
eval_metric_fn_keys = eval_config.metrics_set
evaluator_options = {}
for eval_metric_fn_key in eval_metric_fn_keys:
if eval_metric_fn_key in ('coco_detection_metrics', 'coco_mask_metrics'):
evaluator_options[eval_metric_fn_key] = {
'include_metrics_per_category': (
eval_config.include_metrics_per_category)
}
return evaluator_options
......@@ -23,6 +23,7 @@ import tensorflow as tf
from object_detection import eval_util
from object_detection.core import standard_fields as fields
from object_detection.protos import eval_pb2
class EvalUtilTest(tf.test.TestCase):
......@@ -64,11 +65,12 @@ class EvalUtilTest(tf.test.TestCase):
groundtruth)
def test_get_eval_metric_ops_for_coco_detections(self):
evaluation_metrics = ['coco_detection_metrics']
eval_config = eval_pb2.EvalConfig()
eval_config.metrics_set.extend(['coco_detection_metrics'])
categories = self._get_categories_list()
eval_dict = self._make_evaluation_dict()
metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
evaluation_metrics, categories, eval_dict)
eval_config, categories, eval_dict)
_, update_op = metric_ops['DetectionBoxes_Precision/mAP']
with self.test_session() as sess:
......@@ -82,12 +84,13 @@ class EvalUtilTest(tf.test.TestCase):
self.assertNotIn('DetectionMasks_Precision/mAP', metrics)
def test_get_eval_metric_ops_for_coco_detections_and_masks(self):
evaluation_metrics = ['coco_detection_metrics',
'coco_mask_metrics']
eval_config = eval_pb2.EvalConfig()
eval_config.metrics_set.extend(
['coco_detection_metrics', 'coco_mask_metrics'])
categories = self._get_categories_list()
eval_dict = self._make_evaluation_dict()
metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
evaluation_metrics, categories, eval_dict)
eval_config, categories, eval_dict)
_, update_op_boxes = metric_ops['DetectionBoxes_Precision/mAP']
_, update_op_masks = metric_ops['DetectionMasks_Precision/mAP']
......@@ -102,12 +105,13 @@ class EvalUtilTest(tf.test.TestCase):
self.assertAlmostEqual(1.0, metrics['DetectionMasks_Precision/mAP'])
def test_get_eval_metric_ops_for_coco_detections_and_resized_masks(self):
evaluation_metrics = ['coco_detection_metrics',
'coco_mask_metrics']
eval_config = eval_pb2.EvalConfig()
eval_config.metrics_set.extend(
['coco_detection_metrics', 'coco_mask_metrics'])
categories = self._get_categories_list()
eval_dict = self._make_evaluation_dict(resized_groundtruth_masks=True)
metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
evaluation_metrics, categories, eval_dict)
eval_config, categories, eval_dict)
_, update_op_boxes = metric_ops['DetectionBoxes_Precision/mAP']
_, update_op_masks = metric_ops['DetectionMasks_Precision/mAP']
......@@ -122,13 +126,53 @@ class EvalUtilTest(tf.test.TestCase):
self.assertAlmostEqual(1.0, metrics['DetectionMasks_Precision/mAP'])
def test_get_eval_metric_ops_raises_error_with_unsupported_metric(self):
evaluation_metrics = ['unsupported_metrics']
eval_config = eval_pb2.EvalConfig()
eval_config.metrics_set.extend(['unsupported_metric'])
categories = self._get_categories_list()
eval_dict = self._make_evaluation_dict()
with self.assertRaises(ValueError):
eval_util.get_eval_metric_ops_for_evaluators(
evaluation_metrics, categories, eval_dict)
eval_config, categories, eval_dict)
def test_get_eval_metric_ops_for_evaluators(self):
eval_config = eval_pb2.EvalConfig()
eval_config.metrics_set.extend(
['coco_detection_metrics', 'coco_mask_metrics'])
eval_config.include_metrics_per_category = True
evaluator_options = eval_util.evaluator_options_from_eval_config(
eval_config)
self.assertTrue(evaluator_options['coco_detection_metrics'][
'include_metrics_per_category'])
self.assertTrue(evaluator_options['coco_mask_metrics'][
'include_metrics_per_category'])
def test_get_evaluator_with_evaluator_options(self):
eval_config = eval_pb2.EvalConfig()
eval_config.metrics_set.extend(['coco_detection_metrics'])
eval_config.include_metrics_per_category = True
categories = self._get_categories_list()
evaluator_options = eval_util.evaluator_options_from_eval_config(
eval_config)
evaluator = eval_util.get_evaluators(
eval_config, categories, evaluator_options)
self.assertTrue(evaluator[0]._include_metrics_per_category)
def test_get_evaluator_with_no_evaluator_options(self):
eval_config = eval_pb2.EvalConfig()
eval_config.metrics_set.extend(['coco_detection_metrics'])
eval_config.include_metrics_per_category = True
categories = self._get_categories_list()
evaluator = eval_util.get_evaluators(
eval_config, categories, evaluator_options=None)
# Even though we are setting eval_config.include_metrics_per_category = True
# this option is never passed into the DetectionEvaluator constructor (via
# `evaluator_options`).
self.assertFalse(evaluator[0]._include_metrics_per_category)
if __name__ == '__main__':
tf.test.main()
......@@ -21,6 +21,7 @@ import tempfile
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from object_detection import exporter
......@@ -95,6 +96,12 @@ def append_postprocessing_op(frozen_graph_def, max_detections,
new_output.name = 'TFLite_Detection_PostProcess'
new_output.attr['_output_quantized'].CopyFrom(
attr_value_pb2.AttrValue(b=True))
new_output.attr['_output_types'].list.type.extend([
types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT,
types_pb2.DT_FLOAT
])
new_output.attr['_support_output_type_float_in_quantized_op'].CopyFrom(
attr_value_pb2.AttrValue(b=True))
new_output.attr['max_detections'].CopyFrom(
attr_value_pb2.AttrValue(i=max_detections))
new_output.attr['max_classes_per_detection'].CopyFrom(
......
......@@ -21,6 +21,7 @@ import os
import numpy as np
import six
import tensorflow as tf
from tensorflow.core.framework import types_pb2
from object_detection import export_tflite_ssd_graph_lib
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
......@@ -29,6 +30,7 @@ from object_detection.protos import graph_rewriter_pb2
from object_detection.protos import pipeline_pb2
from object_detection.protos import post_processing_pb2
if six.PY2:
import mock # pylint: disable=g-import-not-at-top
else:
......@@ -122,7 +124,7 @@ class ExportTfliteGraphTest(tf.test.TestCase):
return box_encodings_np, class_predictions_np
def _export_graph(self, pipeline_config, num_channels=3):
"""Exports a tflite graph and an anchor file."""
"""Exports a tflite graph."""
output_dir = self.get_temp_dir()
trained_checkpoint_prefix = os.path.join(output_dir, 'model.ckpt')
tflite_graph_file = os.path.join(output_dir, 'tflite_graph.pb')
......@@ -147,6 +149,34 @@ class ExportTfliteGraphTest(tf.test.TestCase):
max_classes_per_detection=1)
return tflite_graph_file
def _export_graph_with_postprocessing_op(self,
pipeline_config,
num_channels=3):
"""Exports a tflite graph with custom postprocessing op."""
output_dir = self.get_temp_dir()
trained_checkpoint_prefix = os.path.join(output_dir, 'model.ckpt')
tflite_graph_file = os.path.join(output_dir, 'tflite_graph.pb')
quantize = pipeline_config.HasField('graph_rewriter')
self._save_checkpoint_from_mock_model(
trained_checkpoint_prefix,
use_moving_averages=pipeline_config.eval_config.use_moving_averages,
quantize=quantize,
num_channels=num_channels)
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
with tf.Graph().as_default():
export_tflite_ssd_graph_lib.export_tflite_graph(
pipeline_config=pipeline_config,
trained_checkpoint_prefix=trained_checkpoint_prefix,
output_dir=output_dir,
add_postprocessing_op=True,
max_detections=10,
max_classes_per_detection=1)
return tflite_graph_file
def test_export_tflite_graph_with_moving_averages(self):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = True
......@@ -267,6 +297,44 @@ class ExportTfliteGraphTest(tf.test.TestCase):
self.assertAllClose(class_predictions_np,
[[[0.668188, 0.645656], [0.710949, 0.5]]])
def test_export_tflite_graph_with_postprocessing_op(self):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = False
pipeline_config.model.ssd.post_processing.score_converter = (
post_processing_pb2.PostProcessing.SIGMOID)
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 10
pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 10
pipeline_config.model.ssd.num_classes = 2
pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.y_scale = 10.0
pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.x_scale = 10.0
pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.height_scale = 5.0
pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.width_scale = 5.0
tflite_graph_file = self._export_graph_with_postprocessing_op(
pipeline_config)
self.assertTrue(os.path.exists(tflite_graph_file))
graph = tf.Graph()
with graph.as_default():
graph_def = tf.GraphDef()
with tf.gfile.Open(tflite_graph_file) as f:
graph_def.ParseFromString(f.read())
all_op_names = [node.name for node in graph_def.node]
self.assertTrue('TFLite_Detection_PostProcess' in all_op_names)
for node in graph_def.node:
if node.name == 'TFLite_Detection_PostProcess':
self.assertTrue(node.attr['_output_quantized'].b)
self.assertTrue(
node.attr['_support_output_type_float_in_quantized_op'].b)
self.assertAlmostEqual(node.attr['y_scale'].f, 10.0)
self.assertAlmostEqual(node.attr['x_scale'].f, 10.0)
self.assertAlmostEqual(node.attr['h_scale'].f, 5.0)
self.assertAlmostEqual(node.attr['w_scale'].f, 5.0)
self.assertEqual(node.attr['num_classes'].i, 2)
self.assertTrue(
all(t == types_pb2.DT_FLOAT
for t in node.attr['_output_types'].list.type))
if __name__ == '__main__':
tf.test.main()
# Frequently Asked Questions
## Q: How can I ensure that all the groundtruth boxes are used during train and eval?
A: For the object detection framework to be TPU-compliant, we must pad our input
tensors to static shapes. This means that we must pad to a fixed number of
bounding boxes, configured by `InputReader.max_number_of_boxes`. It is
important to set this value to a number larger than the maximum number of
groundtruth boxes in any single image of the dataset. If an image with more
bounding boxes is encountered, the excess boxes are clipped.
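For illustration, a minimal sketch of setting this field through the Python
proto API; the value 200 is a placeholder, not a recommendation:

```python
from object_detection.protos import input_reader_pb2

# Placeholder value: must exceed the maximum number of groundtruth boxes
# found in any single image of the dataset.
input_reader = input_reader_pb2.InputReader()
input_reader.max_number_of_boxes = 200
```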
## Q: AttributeError: 'module' object has no attribute 'BackupHandler'
A: This BackupHandler (tf.contrib.slim.tfexample_decoder.BackupHandler) was
introduced in tensorflow 1.5.0, so running with earlier versions may cause this
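A defensive check one can add before relying on this decoder (a sketch;
`LooseVersion` avoids comparing version strings lexicographically):

```python
from distutils.version import LooseVersion

import tensorflow as tf

# BackupHandler was introduced in TensorFlow 1.5.0; fail fast on older
# versions instead of hitting the AttributeError at decode time.
if LooseVersion(tf.__version__) < LooseVersion('1.5.0'):
  raise RuntimeError('tf.contrib.slim.tfexample_decoder.BackupHandler '
                     'requires TensorFlow >= 1.5.0.')
```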
......
......@@ -11,7 +11,7 @@ Tensorflow Object Detection API depends on the following libraries:
* tf Slim (which is included in the "tensorflow/models/research/" checkout)
* Jupyter notebook
* Matplotlib
* Tensorflow
* Tensorflow (>=1.9.0)
* Cython
* contextlib2
* cocoapi
......
......@@ -44,7 +44,7 @@ job using GPUs. A sample YAML file is given below:
```
trainingInput:
runtimeVersion: "1.8"
runtimeVersion: "1.9"
scaleTier: CUSTOM
masterType: standard_gpu
workerCount: 9
......@@ -73,7 +73,7 @@ following command:
```bash
# From tensorflow/models/research/
gcloud ml-engine jobs submit training object_detection_`date +%m_%d_%Y_%H_%M_%S` \
--runtime-version 1.8 \
--runtime-version 1.9 \
--job-dir=gs://${MODEL_DIR} \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz \
--module-name object_detection.model_main \
......@@ -93,7 +93,7 @@ Google Cloud Storage.
Users can monitor the progress of their training job on the [ML Engine
Dashboard](https://console.cloud.google.com/mlengine/jobs).
Note: This sample is supported for use with 1.8 runtime version.
Note: This sample is supported for use with 1.9 runtime version.
## Running a TPU Training Job on CMLE
......@@ -105,7 +105,7 @@ gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%m_%d_%Y_
--job-dir=gs://${MODEL_DIR} \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz \
--module-name object_detection.model_tpu_main \
--runtime-version 1.8 \
--runtime-version 1.9 \
--scale-tier BASIC_TPU \
--region us-central1 \
-- \
......@@ -133,7 +133,7 @@ job:
```bash
gcloud ml-engine jobs submit training object_detection_eval_`date +%m_%d_%Y_%H_%M_%S` \
--runtime-version 1.8 \
--runtime-version 1.9 \
--job-dir=gs://${MODEL_DIR} \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz \
--module-name object_detection.model_main \
......
......@@ -221,6 +221,14 @@ def pad_input_data_to_static_shapes(tensor_dict, max_num_boxes, num_classes,
for tensor_name in tensor_dict:
padded_tensor_dict[tensor_name] = shape_utils.pad_or_clip_nd(
tensor_dict[tensor_name], padding_shapes[tensor_name])
# Make sure that the number of groundtruth boxes now reflects the
# padded/clipped tensors.
if fields.InputDataFields.num_groundtruth_boxes in padded_tensor_dict:
padded_tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = (
tf.minimum(
padded_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
max_num_boxes))
return padded_tensor_dict
......
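As an aside, a minimal standalone sketch (assuming TF 1.x graph mode) of the
clipping semantics introduced above:

```python
import tensorflow as tf

# If the recorded box count (here 5) exceeds max_num_boxes (here 3), clip it
# so that downstream slicing in model_lib.unstack_batch() stays within the
# padded tensor.
num_groundtruth_boxes = tf.constant(5, dtype=tf.int32)
max_num_boxes = 3
clipped = tf.minimum(num_groundtruth_boxes, max_num_boxes)

with tf.Session() as sess:
  print(sess.run(clipped))  # => 3
```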
......@@ -663,6 +663,8 @@ class PadInputDataToStaticShapesFnTest(tf.test.TestCase):
tf.placeholder(tf.float32, [None, 4]),
fields.InputDataFields.groundtruth_classes:
tf.placeholder(tf.int32, [None, 3]),
fields.InputDataFields.num_groundtruth_boxes:
tf.placeholder(tf.int32, [])
}
padded_tensor_dict = inputs.pad_input_data_to_static_shapes(
tensor_dict=input_tensor_dict,
......@@ -685,6 +687,8 @@ class PadInputDataToStaticShapesFnTest(tf.test.TestCase):
np.random.rand(5, 4),
input_tensor_dict[fields.InputDataFields.groundtruth_classes]:
np.random.rand(2, 3),
input_tensor_dict[fields.InputDataFields.num_groundtruth_boxes]:
5,
})
self.assertAllEqual(
......@@ -692,6 +696,9 @@ class PadInputDataToStaticShapesFnTest(tf.test.TestCase):
self.assertAllEqual(
out_tensor_dict[fields.InputDataFields.groundtruth_classes].shape,
[3, 3])
self.assertEqual(
out_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
3)
def test_do_not_pad_dynamic_images(self):
input_tensor_dict = {
......
......@@ -172,7 +172,8 @@ def _create_losses(input_queue, create_model_fn, train_config):
"""
detection_model = create_model_fn()
(images, _, groundtruth_boxes_list, groundtruth_classes_list,
groundtruth_masks_list, groundtruth_keypoints_list, _) = get_inputs(
groundtruth_masks_list, groundtruth_keypoints_list,
groundtruth_weights_list) = get_inputs(
input_queue,
detection_model.num_classes,
train_config.merge_multiple_label_boxes,
......@@ -193,10 +194,12 @@ def _create_losses(input_queue, create_model_fn, train_config):
if any(keypoints is None for keypoints in groundtruth_keypoints_list):
groundtruth_keypoints_list = None
detection_model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list,
groundtruth_masks_list,
groundtruth_keypoints_list)
detection_model.provide_groundtruth(
groundtruth_boxes_list,
groundtruth_classes_list,
groundtruth_masks_list,
groundtruth_keypoints_list,
groundtruth_weights_list=groundtruth_weights_list)
prediction_dict = detection_model.predict(images, true_image_shapes)
losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
......
......@@ -97,6 +97,7 @@ from functools import partial
import tensorflow as tf
from object_detection.anchor_generators import grid_anchor_generator
from object_detection.builders import box_predictor_builder
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import box_predictor
......@@ -105,7 +106,6 @@ from object_detection.core import model
from object_detection.core import post_processing
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.predictors import convolutional_box_predictor
from object_detection.utils import ops
from object_detection.utils import shape_utils
......@@ -413,17 +413,17 @@ class FasterRCNNMetaArch(model.DetectionModel):
self._first_stage_minibatch_size = first_stage_minibatch_size
self._first_stage_sampler = first_stage_sampler
self._first_stage_box_predictor = (
convolutional_box_predictor.ConvolutionalBoxPredictor(
self._is_training,
box_predictor_builder.build_convolutional_box_predictor(
is_training=self._is_training,
num_classes=1,
conv_hyperparams_fn=self._first_stage_box_predictor_arg_scope_fn,
min_depth=0,
max_depth=0,
num_layers_before_predictor=0,
use_dropout=False,
dropout_keep_prob=1.0,
box_code_size=self._box_coder.code_size,
kernel_size=1,
box_code_size=self._box_coder.code_size))
num_layers_before_predictor=0,
min_depth=0,
max_depth=0))
self._first_stage_nms_score_threshold = first_stage_nms_score_threshold
self._first_stage_nms_iou_threshold = first_stage_nms_iou_threshold
......@@ -1236,11 +1236,13 @@ class FasterRCNNMetaArch(model.DetectionModel):
proposal_boxes = tf.stop_gradient(proposal_boxes)
if not self._hard_example_miner:
(groundtruth_boxlists, groundtruth_classes_with_background_list, _,
_) = self._format_groundtruth_data(true_image_shapes)
groundtruth_weights_list
) = self._format_groundtruth_data(true_image_shapes)
(proposal_boxes, proposal_scores,
num_proposals) = self._unpad_proposals_and_sample_box_classifier_batch(
num_proposals) = self._sample_box_classifier_batch(
proposal_boxes, proposal_scores, num_proposals,
groundtruth_boxlists, groundtruth_classes_with_background_list)
groundtruth_boxlists, groundtruth_classes_with_background_list,
groundtruth_weights_list)
# normalize proposal boxes
def normalize_boxes(args):
proposal_boxes_per_image = args[0]
......@@ -1253,14 +1255,15 @@ class FasterRCNNMetaArch(model.DetectionModel):
normalize_boxes, elems=[proposal_boxes, image_shapes], dtype=tf.float32)
return normalized_proposal_boxes, proposal_scores, num_proposals
def _unpad_proposals_and_sample_box_classifier_batch(
def _sample_box_classifier_batch(
self,
proposal_boxes,
proposal_scores,
num_proposals,
groundtruth_boxlists,
groundtruth_classes_with_background_list):
"""Unpads proposals and samples a minibatch for second stage.
groundtruth_classes_with_background_list,
groundtruth_weights_list):
"""Samples a minibatch for second stage.
Args:
proposal_boxes: A float tensor with shape
......@@ -1278,6 +1281,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
groundtruth_classes_with_background_list: A list of 2-D one-hot
(or k-hot) tensors of shape [num_boxes, num_classes+1] containing the
class targets with the 0th index assumed to map to the background class.
groundtruth_weights_list: A list of 1-D tensors of shape [num_boxes]
indicating the weight associated with the groundtruth boxes.
Returns:
proposal_boxes: A float tensor with shape
......@@ -1298,31 +1303,23 @@ class FasterRCNNMetaArch(model.DetectionModel):
single_image_proposal_scores,
single_image_num_proposals,
single_image_groundtruth_boxlist,
single_image_groundtruth_classes_with_background) in zip(
single_image_groundtruth_classes_with_background,
single_image_groundtruth_weights) in zip(
tf.unstack(proposal_boxes),
tf.unstack(proposal_scores),
tf.unstack(num_proposals),
groundtruth_boxlists,
groundtruth_classes_with_background_list):
static_shape = single_image_proposal_boxes.get_shape()
sliced_static_shape = tf.TensorShape([tf.Dimension(None),
static_shape.dims[-1]])
single_image_proposal_boxes = tf.slice(
single_image_proposal_boxes,
[0, 0],
[single_image_num_proposals, -1])
single_image_proposal_boxes.set_shape(sliced_static_shape)
single_image_proposal_scores = tf.slice(single_image_proposal_scores,
[0],
[single_image_num_proposals])
groundtruth_classes_with_background_list,
groundtruth_weights_list):
single_image_boxlist = box_list.BoxList(single_image_proposal_boxes)
single_image_boxlist.add_field(fields.BoxListFields.scores,
single_image_proposal_scores)
sampled_boxlist = self._sample_box_classifier_minibatch(
sampled_boxlist = self._sample_box_classifier_minibatch_single_image(
single_image_boxlist,
single_image_num_proposals,
single_image_groundtruth_boxlist,
single_image_groundtruth_classes_with_background)
single_image_groundtruth_classes_with_background,
single_image_groundtruth_weights)
sampled_padded_boxlist = box_list_ops.pad_or_clip_box_list(
sampled_boxlist,
num_boxes=self._second_stage_batch_size)
......@@ -1394,18 +1391,23 @@ class FasterRCNNMetaArch(model.DetectionModel):
resized_masks_list.append(resized_mask)
groundtruth_masks_list = resized_masks_list
groundtruth_weights_list = None
if self.groundtruth_has_field(fields.BoxListFields.weights):
groundtruth_weights_list = self.groundtruth_lists(
fields.BoxListFields.weights)
else:
# Set weights for all batch elements equally to 1.0
groundtruth_weights_list = []
for groundtruth_classes in groundtruth_classes_with_background_list:
num_gt = tf.shape(groundtruth_classes)[0]
groundtruth_weights = tf.ones(num_gt)
groundtruth_weights_list.append(groundtruth_weights)
return (groundtruth_boxlists, groundtruth_classes_with_background_list,
groundtruth_masks_list, groundtruth_weights_list)
def _sample_box_classifier_minibatch(self,
proposal_boxlist,
groundtruth_boxlist,
groundtruth_classes_with_background):
def _sample_box_classifier_minibatch_single_image(
self, proposal_boxlist, num_valid_proposals, groundtruth_boxlist,
groundtruth_classes_with_background, groundtruth_weights):
"""Samples a mini-batch of proposals to be sent to the box classifier.
Helper function for self._postprocess_rpn.
......@@ -1413,12 +1415,14 @@ class FasterRCNNMetaArch(model.DetectionModel):
Args:
proposal_boxlist: A BoxList containing K proposal boxes in absolute
coordinates.
num_valid_proposals: Number of valid proposals in the proposal boxlist.
groundtruth_boxlist: A Boxlist containing N groundtruth object boxes in
absolute coordinates.
groundtruth_classes_with_background: A tensor with shape
`[N, self.num_classes + 1]` representing groundtruth classes. The
classes are assumed to be k-hot encoded, and include background as the
zero-th class.
groundtruth_weights: Weights attached to the groundtruth_boxes.
Returns:
a BoxList containing the sampled proposals.
......@@ -1428,15 +1432,19 @@ class FasterRCNNMetaArch(model.DetectionModel):
groundtruth_boxlist,
groundtruth_classes_with_background,
unmatched_class_label=tf.constant(
[1] + self._num_classes * [0], dtype=tf.float32))
[1] + self._num_classes * [0], dtype=tf.float32),
groundtruth_weights=groundtruth_weights)
# Selects all boxes as candidates if none of them is selected according
# to cls_weights. This could happen as boxes within certain IOU ranges
# are ignored. If triggered, the selected boxes will still be ignored
# during loss computation.
cls_weights += tf.to_float(tf.equal(tf.reduce_sum(cls_weights), 0))
positive_indicator = tf.greater(tf.argmax(cls_targets, axis=1), 0)
valid_indicator = tf.logical_and(
tf.range(proposal_boxlist.num_boxes()) < num_valid_proposals,
cls_weights > 0
)
sampled_indices = self._second_stage_sampler.subsample(
tf.cast(cls_weights, tf.bool),
valid_indicator,
self._second_stage_batch_size,
positive_indicator)
return box_list_ops.boolean_mask(proposal_boxlist, sampled_indices)
......@@ -1704,9 +1712,10 @@ class FasterRCNNMetaArch(model.DetectionModel):
with tf.name_scope('RPNLoss'):
(batch_cls_targets, batch_cls_weights, batch_reg_targets,
batch_reg_weights, _) = target_assigner.batch_assign_targets(
self._proposal_target_assigner, box_list.BoxList(anchors),
groundtruth_boxlists,
len(groundtruth_boxlists) * [None],
target_assigner=self._proposal_target_assigner,
anchors_batch=box_list.BoxList(anchors),
gt_box_batch=groundtruth_boxlists,
gt_class_targets_batch=(len(groundtruth_boxlists) * [None]),
gt_weights_batch=groundtruth_weights_list)
batch_cls_targets = tf.squeeze(batch_cls_targets, axis=2)
......@@ -1827,10 +1836,10 @@ class FasterRCNNMetaArch(model.DetectionModel):
(batch_cls_targets_with_background, batch_cls_weights, batch_reg_targets,
batch_reg_weights, _) = target_assigner.batch_assign_targets(
self._detector_target_assigner,
proposal_boxlists,
groundtruth_boxlists,
groundtruth_classes_with_background_list,
target_assigner=self._detector_target_assigner,
anchors_batch=proposal_boxlists,
gt_box_batch=groundtruth_boxlists,
gt_class_targets_batch=groundtruth_classes_with_background_list,
unmatched_class_label=tf.constant(
[1] + self._num_classes * [0], dtype=tf.float32),
gt_weights_batch=groundtruth_weights_list)
......@@ -1908,9 +1917,12 @@ class FasterRCNNMetaArch(model.DetectionModel):
unmatched_mask_label = tf.zeros(image_shape[1:3], dtype=tf.float32)
(batch_mask_targets, _, _, batch_mask_target_weights,
_) = target_assigner.batch_assign_targets(
self._detector_target_assigner, proposal_boxlists,
groundtruth_boxlists, groundtruth_masks_list, unmatched_mask_label,
groundtruth_weights_list)
target_assigner=self._detector_target_assigner,
anchors_batch=proposal_boxlists,
gt_box_batch=groundtruth_boxlists,
gt_class_targets_batch=groundtruth_masks_list,
unmatched_class_label=unmatched_mask_label,
gt_weights_batch=groundtruth_weights_list)
# Pad the prediction_masks to add zeros for the background class, to be
# consistent with class predictions.
......
......@@ -230,9 +230,14 @@ class FasterRCNNMetaArchTest(
tf.constant([[1, 0], [0, 1]], dtype=tf.float32),
tf.constant([[1, 0], [1, 0]], dtype=tf.float32)
]
groundtruth_weights_list = [
tf.constant([1, 1], dtype=tf.float32),
tf.constant([1, 1], dtype=tf.float32)]
_, true_image_shapes = model.preprocess(tf.zeros(image_shape))
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
model.provide_groundtruth(
groundtruth_boxes_list,
groundtruth_classes_list,
groundtruth_weights_list=groundtruth_weights_list)
result_tensor_dict = model.predict(preprocessed_inputs, true_image_shapes)
mask_shape_1 = 1 if masks_are_class_agnostic else model._num_classes
......
......@@ -511,10 +511,14 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
groundtruth_classes_list = [
tf.constant([[1, 0], [0, 1]], dtype=tf.float32),
tf.constant([[1, 0], [1, 0]], dtype=tf.float32)]
groundtruth_weights_list = [
tf.constant([1, 1], dtype=tf.float32),
tf.constant([1, 1], dtype=tf.float32)]
_, true_image_shapes = model.preprocess(tf.zeros(image_shape))
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
model.provide_groundtruth(
groundtruth_boxes_list,
groundtruth_classes_list,
groundtruth_weights_list=groundtruth_weights_list)
result_tensor_dict = model.predict(preprocessed_inputs, true_image_shapes)
expected_shapes = {
......@@ -663,10 +667,15 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
tf.constant([[0, .5, .5, 1], [.5, 0, 1, .5]], dtype=tf.float32)]
groundtruth_classes_list = [tf.constant([[1, 0], [0, 1]], dtype=tf.float32),
tf.constant([[1, 0], [1, 0]], dtype=tf.float32)]
groundtruth_weights_list = [
tf.constant([1, 1], dtype=tf.float32),
tf.constant([1, 1], dtype=tf.float32)
]
_, true_image_shapes = model.preprocess(tf.zeros(image_shape))
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
model.provide_groundtruth(
groundtruth_boxes_list,
groundtruth_classes_list,
groundtruth_weights_list=groundtruth_weights_list)
proposals = model.postprocess({
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
......
......@@ -243,7 +243,9 @@ class SSDMetaArch(model.DetectionModel):
freeze_batchnorm=False,
inplace_batchnorm_update=False,
add_background_class=True,
random_example_sampler=None):
random_example_sampler=None,
expected_classification_loss_under_sampling=None,
target_assigner_instance=None):
"""SSDMetaArch Constructor.
TODO(rathodv,jonathanhuang): group NMS parameters + score converter into
......@@ -308,6 +310,9 @@ class SSDMetaArch(model.DetectionModel):
example miner can both be applied to the model. In that case, random
sampler will take effect first and hard example miner can only process
the random sampled examples.
expected_classification_loss_under_sampling: If not None, used to
calculate the classification loss with background/foreground weighting.
target_assigner_instance: target_assigner.TargetAssigner instance to use.
"""
super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes)
self._is_training = is_training
......@@ -342,11 +347,14 @@ class SSDMetaArch(model.DetectionModel):
self._unmatched_class_label = tf.constant((self.num_classes + 1) * [0],
tf.float32)
self._target_assigner = target_assigner.TargetAssigner(
self._region_similarity_calculator,
self._matcher,
self._box_coder,
negative_class_weight=negative_class_weight)
if target_assigner_instance:
self._target_assigner = target_assigner_instance
else:
self._target_assigner = target_assigner.TargetAssigner(
self._region_similarity_calculator,
self._matcher,
self._box_coder,
negative_class_weight=negative_class_weight)
self._classification_loss = classification_loss
self._localization_loss = localization_loss
......@@ -365,6 +373,8 @@ class SSDMetaArch(model.DetectionModel):
self._anchors = None
self._add_summaries = add_summaries
self._batched_prediction_tensor_names = []
self._expected_classification_loss_under_sampling = (
expected_classification_loss_under_sampling)
@property
def anchors(self):
......@@ -696,19 +706,34 @@ class SSDMetaArch(model.DetectionModel):
batch_reg_targets,
ignore_nan_targets=True,
weights=batch_reg_weights)
cls_losses = ops.reduce_sum_trailing_dimensions(
self._classification_loss(
prediction_dict['class_predictions_with_background'],
batch_cls_targets,
weights=batch_cls_weights),
ndims=2)
if self._hard_example_miner:
cls_losses = self._classification_loss(
prediction_dict['class_predictions_with_background'],
batch_cls_targets,
weights=batch_cls_weights)
if self._expected_classification_loss_under_sampling:
if cls_losses.get_shape().ndims == 3:
batch_size, num_anchors, num_classes = cls_losses.get_shape()
cls_losses = tf.reshape(cls_losses, [batch_size, -1])
batch_cls_targets = tf.reshape(
batch_cls_targets, [batch_size, num_anchors * num_classes, -1])
batch_cls_targets = tf.concat(
[1 - batch_cls_targets, batch_cls_targets], axis=-1)
cls_losses = self._expected_classification_loss_under_sampling(
batch_cls_targets, cls_losses)
classification_loss = tf.reduce_sum(cls_losses)
localization_loss = tf.reduce_sum(location_losses)
elif self._hard_example_miner:
cls_losses = ops.reduce_sum_trailing_dimensions(cls_losses, ndims=2)
(localization_loss, classification_loss) = self._apply_hard_mining(
location_losses, cls_losses, prediction_dict, match_list)
if self._add_summaries:
self._hard_example_miner.summarize()
else:
cls_losses = ops.reduce_sum_trailing_dimensions(cls_losses, ndims=2)
if self._add_summaries:
class_ids = tf.argmax(batch_cls_targets, axis=2)
flattened_class_ids = tf.reshape(class_ids, [-1])
......@@ -993,4 +1018,3 @@ class SSDMetaArch(model.DetectionModel):
variables_to_restore[var_name] = variable
return variables_to_restore
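A hedged sketch, mirroring the test changes below, of wiring the new hook into
the constructor; the sampling values are illustrative:

```python
import functools

from object_detection.utils import ops

# Bind the sampling parameters once, then pass the resulting partial to
# SSDMetaArch via the `expected_classification_loss_under_sampling` argument.
expected_classification_loss = functools.partial(
    ops.expected_classification_loss_under_sampling,
    minimum_negative_sampling=1,
    desired_negative_sampling_ratio=3)
```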
......@@ -26,7 +26,9 @@ from object_detection.core import box_list
from object_detection.core import losses
from object_detection.core import post_processing
from object_detection.core import region_similarity_calculator as sim_calc
from object_detection.core import target_assigner
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.utils import ops
from object_detection.utils import test_case
from object_detection.utils import test_utils
......@@ -117,6 +119,10 @@ class SsdMetaArchTest(test_case.TestCase, parameterized.TestCase):
normalize_loc_loss_by_codesize=False,
add_background_class=True,
random_example_sampling=False,
weight_regression_loss_by_score=False,
use_expected_classification_loss_under_sampling=False,
minimum_negative_sampling=1,
desired_negative_sampling_ratio=3,
use_keras=False):
is_training = False
num_classes = 1
......@@ -163,6 +169,20 @@ class SsdMetaArchTest(test_case.TestCase, parameterized.TestCase):
random_example_sampler = sampler.BalancedPositiveNegativeSampler(
positive_fraction=0.5)
target_assigner_instance = target_assigner.TargetAssigner(
region_similarity_calculator,
mock_matcher,
mock_box_coder,
negative_class_weight=negative_class_weight,
weight_regression_loss_by_score=weight_regression_loss_by_score)
expected_classification_loss_under_sampling = None
if use_expected_classification_loss_under_sampling:
expected_classification_loss_under_sampling = functools.partial(
ops.expected_classification_loss_under_sampling,
minimum_negative_sampling=minimum_negative_sampling,
desired_negative_sampling_ratio=desired_negative_sampling_ratio)
code_size = 4
model = ssd_meta_arch.SSDMetaArch(
is_training,
......@@ -183,12 +203,15 @@ class SsdMetaArchTest(test_case.TestCase, parameterized.TestCase):
localization_loss_weight,
normalize_loss_by_num_matches,
hard_example_miner,
target_assigner_instance=target_assigner_instance,
add_summaries=False,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=False,
inplace_batchnorm_update=False,
add_background_class=add_background_class,
random_example_sampler=random_example_sampler)
random_example_sampler=random_example_sampler,
expected_classification_loss_under_sampling=
expected_classification_loss_under_sampling)
return model, num_classes, mock_anchor_generator.num_anchors(), code_size
def test_preprocess_preserves_shapes_with_dynamic_input_image(
......@@ -470,6 +493,94 @@ class SsdMetaArchTest(test_case.TestCase, parameterized.TestCase):
groundtruth_classes1 = np.array([[0, 1]], dtype=np.float32)
groundtruth_classes2 = np.array([[0, 1]], dtype=np.float32)
expected_localization_loss = 0.0
expected_classification_loss = (
batch_size * num_anchors * (num_classes + 1) * np.log(2.0))
(localization_loss, classification_loss) = self.execute(
graph_fn, [
preprocessed_input, groundtruth_boxes1, groundtruth_boxes2,
groundtruth_classes1, groundtruth_classes2
])
self.assertAllClose(localization_loss, expected_localization_loss)
self.assertAllClose(classification_loss, expected_classification_loss)
def test_loss_with_expected_classification_loss(self, use_keras):
with tf.Graph().as_default():
_, num_classes, num_anchors, _ = self._create_model(use_keras=use_keras)
def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2,
groundtruth_classes1, groundtruth_classes2):
groundtruth_boxes_list = [groundtruth_boxes1, groundtruth_boxes2]
groundtruth_classes_list = [groundtruth_classes1, groundtruth_classes2]
model, _, _, _ = self._create_model(
apply_hard_mining=False,
add_background_class=True,
use_expected_classification_loss_under_sampling=True,
minimum_negative_sampling=1,
desired_negative_sampling_ratio=desired_negative_sampling_ratio)
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
prediction_dict = model.predict(
preprocessed_tensor, true_image_shapes=None)
loss_dict = model.loss(prediction_dict, true_image_shapes=None)
return (loss_dict['Loss/localization_loss'],
loss_dict['Loss/classification_loss'])
batch_size = 2
desired_negative_sampling_ratio = 4
preprocessed_input = np.random.rand(batch_size, 2, 2, 3).astype(np.float32)
groundtruth_boxes1 = np.array([[0, 0, .5, .5]], dtype=np.float32)
groundtruth_boxes2 = np.array([[0, 0, .5, .5]], dtype=np.float32)
groundtruth_classes1 = np.array([[1]], dtype=np.float32)
groundtruth_classes2 = np.array([[1]], dtype=np.float32)
expected_localization_loss = 0.0
expected_classification_loss = (
batch_size * (desired_negative_sampling_ratio * num_anchors +
num_classes * num_anchors) * np.log(2.0))
(localization_loss, classification_loss) = self.execute(
graph_fn, [
preprocessed_input, groundtruth_boxes1, groundtruth_boxes2,
groundtruth_classes1, groundtruth_classes2
])
self.assertAllClose(localization_loss, expected_localization_loss)
self.assertAllClose(classification_loss, expected_classification_loss)
def test_loss_results_are_correct_with_weight_regression_loss_by_score(
self, use_keras):
with tf.Graph().as_default():
_, num_classes, num_anchors, _ = self._create_model(
use_keras=use_keras,
add_background_class=False,
weight_regression_loss_by_score=True)
def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2,
groundtruth_classes1, groundtruth_classes2):
groundtruth_boxes_list = [groundtruth_boxes1, groundtruth_boxes2]
groundtruth_classes_list = [groundtruth_classes1, groundtruth_classes2]
model, _, _, _ = self._create_model(
use_keras=use_keras,
apply_hard_mining=False,
add_background_class=False,
weight_regression_loss_by_score=True)
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
prediction_dict = model.predict(
preprocessed_tensor, true_image_shapes=None)
loss_dict = model.loss(prediction_dict, true_image_shapes=None)
return (loss_dict['Loss/localization_loss'],
loss_dict['Loss/classification_loss'])
batch_size = 2
preprocessed_input = np.random.rand(batch_size, 2, 2, 3).astype(np.float32)
groundtruth_boxes1 = np.array([[0, 0, 1, 1]], dtype=np.float32)
groundtruth_boxes2 = np.array([[0, 0, 1, 1]], dtype=np.float32)
groundtruth_classes1 = np.array([[0, 1]], dtype=np.float32)
groundtruth_classes2 = np.array([[1, 0]], dtype=np.float32)
expected_localization_loss = 0.25
expected_classification_loss = (
batch_size * num_anchors * (num_classes + 1) * np.log(2.0))
(localization_loss, classification_loss) = self.execute(
......
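A hedged sketch of constructing a TargetAssigner that weights the regression
loss by foreground scores, as exercised above; the similarity calculator,
matcher, and box coder choices are illustrative:

```python
from object_detection.box_coders import faster_rcnn_box_coder
from object_detection.core import region_similarity_calculator
from object_detection.core import target_assigner
from object_detection.matchers import argmax_matcher

# The regression loss of each matched anchor is scaled by the foreground
# score taken from the groundtruth labels.
assigner = target_assigner.TargetAssigner(
    region_similarity_calculator.IouSimilarity(),
    argmax_matcher.ArgMaxMatcher(matched_threshold=0.5),
    faster_rcnn_box_coder.FasterRcnnBoxCoder(),
    negative_class_weight=1.0,
    weight_regression_loss_by_score=True)
```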
......@@ -201,14 +201,8 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
for key, value in iter(box_metrics.items())}
return box_metrics
def get_estimator_eval_metric_ops(self, image_id, groundtruth_boxes,
groundtruth_classes,
detection_boxes,
detection_scores, detection_classes,
groundtruth_is_crowd=None,
num_gt_boxes_per_image=None,
num_det_boxes_per_image=None):
"""Returns a dictionary of eval metric ops to use with `tf.EstimatorSpec`.
def get_estimator_eval_metric_ops(self, eval_dict):
"""Returns a dictionary of eval metric ops.
Note that once value_op is called, the detections and groundtruth added via
update_op are cleared.
......@@ -218,35 +212,18 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
tensors need not be present.
Args:
image_id: string/integer tensor of shape [batch] with unique identifiers
for the images.
groundtruth_boxes: float32 tensor of shape [batch, num_boxes, 4]
containing `num_boxes` groundtruth boxes of the format
[ymin, xmin, ymax, xmax] in absolute image coordinates.
groundtruth_classes: int32 tensor of shape [batch, num_boxes] containing
1-indexed groundtruth classes for the boxes.
detection_boxes: float32 tensor of shape [batch, num_boxes, 4] containing
`num_boxes` detection boxes of the format [ymin, xmin, ymax, xmax]
in absolute image coordinates.
detection_scores: float32 tensor of shape [batch, num_boxes] containing
detection scores for the boxes.
detection_classes: int32 tensor of shape [batch, num_boxes] containing
1-indexed detection classes for the boxes.
groundtruth_is_crowd: bool tensor of shape [batch, num_boxes] containing
is_crowd annotations. This field is optional, and if not passed, then
all boxes are treated as *not* is_crowd.
num_gt_boxes_per_image: int32 tensor of shape [batch] containing the
number of groundtruth boxes per image. If None, will assume no padding
in groundtruth tensors.
num_det_boxes_per_image: int32 tensor of shape [batch] containing the
number of detection boxes per image. If None, will assume no padding in
the detection tensors.
eval_dict: A dictionary that holds tensors for evaluating object detection
performance. For single-image evaluation, this dictionary may be
produced from eval_util.result_dict_for_single_example(). For multi-image
evaluation, `eval_dict` should contain the fields
'num_groundtruth_boxes_per_image' and 'num_det_boxes_per_image' so that
the tensors can be properly unpadded from the batch.
Returns:
a dictionary of metric names to tuple of value_op and update_op that can
be used as eval metric ops in tf.EstimatorSpec. Note that all update ops
must be run together and similarly all value ops must be run together to
guarantee correct behaviour.
be used as eval metric ops in tf.estimator.EstimatorSpec. Note that all
update ops must be run together and similarly all value ops must be run
together to guarantee correct behaviour.
"""
def update_op(
image_id_batched,
......@@ -278,6 +255,22 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
'detection_scores': det_score[:num_det_box],
'detection_classes': det_class[:num_det_box]})
# Unpack items from the evaluation dictionary.
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
image_id = eval_dict[input_data_fields.key]
groundtruth_boxes = eval_dict[input_data_fields.groundtruth_boxes]
groundtruth_classes = eval_dict[input_data_fields.groundtruth_classes]
groundtruth_is_crowd = eval_dict.get(
input_data_fields.groundtruth_is_crowd, None)
detection_boxes = eval_dict[detection_fields.detection_boxes]
detection_scores = eval_dict[detection_fields.detection_scores]
detection_classes = eval_dict[detection_fields.detection_classes]
num_gt_boxes_per_image = eval_dict.get(
'num_groundtruth_boxes_per_image', None)
num_det_boxes_per_image = eval_dict.get(
'num_det_boxes_per_image', None)
if groundtruth_is_crowd is None:
groundtruth_is_crowd = tf.zeros_like(groundtruth_classes, dtype=tf.bool)
if not image_id.shape.as_list():
......@@ -553,42 +546,22 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
for key, value in iter(mask_metrics.items())}
return mask_metrics
def get_estimator_eval_metric_ops(self, image_id, groundtruth_boxes,
groundtruth_classes,
groundtruth_instance_masks,
detection_scores, detection_classes,
detection_masks, groundtruth_is_crowd=None):
"""Returns a dictionary of eval metric ops to use with `tf.EstimatorSpec`.
def get_estimator_eval_metric_ops(self, eval_dict):
"""Returns a dictionary of eval metric ops.
Note that once value_op is called, the detections and groundtruth added via
update_op are cleared.
Args:
image_id: Unique string/integer identifier for the image.
groundtruth_boxes: float32 tensor of shape [num_boxes, 4] containing
`num_boxes` groundtruth boxes of the format
[ymin, xmin, ymax, xmax] in absolute image coordinates.
groundtruth_classes: int32 tensor of shape [num_boxes] containing
1-indexed groundtruth classes for the boxes.
groundtruth_instance_masks: uint8 tensor array of shape
[num_boxes, image_height, image_width] containing groundtruth masks
corresponding to the boxes. The elements of the array must be in {0, 1}.
detection_scores: float32 tensor of shape [num_boxes] containing
detection scores for the boxes.
detection_classes: int32 tensor of shape [num_boxes] containing
1-indexed detection classes for the boxes.
detection_masks: uint8 tensor array of shape
[num_boxes, image_height, image_width] containing instance masks
corresponding to the boxes. The elements of the array must be in {0, 1}.
groundtruth_is_crowd: bool tensor of shape [batch, num_boxes] containing
is_crowd annotations. This field is optional, and if not passed, then
all boxes are treated as *not* is_crowd.
eval_dict: A dictionary that holds tensors for evaluating object detection
performance. This dictionary may be produced from
eval_util.result_dict_for_single_example().
Returns:
a dictionary of metric names to tuple of value_op and update_op that can
be used as eval metric ops in tf.EstimatorSpec. Note that all update ops
must be run together and similarly all value ops must be run together to
guarantee correct behaviour.
be used as eval metric ops in tf.estimator.EstimatorSpec. Note that all
update ops must be run together and similarly all value ops must be run
together to guarantee correct behaviour.
"""
def update_op(
image_id,
......@@ -599,6 +572,7 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
detection_scores,
detection_classes,
detection_masks):
"""Update op for metrics."""
self.add_single_ground_truth_image_info(
image_id,
{'groundtruth_boxes': groundtruth_boxes,
......@@ -611,6 +585,20 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
'detection_classes': detection_classes,
'detection_masks': detection_masks})
# Unpack items from the evaluation dictionary.
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
image_id = eval_dict[input_data_fields.key]
groundtruth_boxes = eval_dict[input_data_fields.groundtruth_boxes]
groundtruth_classes = eval_dict[input_data_fields.groundtruth_classes]
groundtruth_instance_masks = eval_dict[
input_data_fields.groundtruth_instance_masks]
groundtruth_is_crowd = eval_dict.get(
input_data_fields.groundtruth_is_crowd, None)
detection_scores = eval_dict[detection_fields.detection_scores]
detection_classes = eval_dict[detection_fields.detection_classes]
detection_masks = eval_dict[detection_fields.detection_masks]
if groundtruth_is_crowd is None:
groundtruth_is_crowd = tf.zeros_like(groundtruth_classes, dtype=tf.bool)
update_op = tf.py_func(update_op, [image_id,
......
......@@ -258,12 +258,18 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
detection_scores = tf.placeholder(tf.float32, shape=(None))
detection_classes = tf.placeholder(tf.float32, shape=(None))
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(
image_id, groundtruth_boxes,
groundtruth_classes,
detection_boxes,
detection_scores,
detection_classes)
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key: image_id,
input_data_fields.groundtruth_boxes: groundtruth_boxes,
input_data_fields.groundtruth_classes: groundtruth_classes,
detection_fields.detection_boxes: detection_boxes,
detection_fields.detection_scores: detection_scores,
detection_fields.detection_classes: detection_classes
}
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(eval_dict)
_, update_op = eval_metric_ops['DetectionBoxes_Precision/mAP']
......@@ -336,9 +342,18 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
detection_scores = tf.placeholder(tf.float32, shape=(None))
detection_classes = tf.placeholder(tf.float32, shape=(None))
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(
image_id, groundtruth_boxes, groundtruth_classes, detection_boxes,
detection_scores, detection_classes)
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key: image_id,
input_data_fields.groundtruth_boxes: groundtruth_boxes,
input_data_fields.groundtruth_classes: groundtruth_classes,
detection_fields.detection_boxes: detection_boxes,
detection_fields.detection_scores: detection_scores,
detection_fields.detection_classes: detection_classes
}
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(eval_dict)
_, update_op = eval_metric_ops['DetectionBoxes_Precision/mAP']
......@@ -426,12 +441,18 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
detection_scores = tf.placeholder(tf.float32, shape=(batch_size, None))
detection_classes = tf.placeholder(tf.float32, shape=(batch_size, None))
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(
image_id, groundtruth_boxes,
groundtruth_classes,
detection_boxes,
detection_scores,
detection_classes)
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key: image_id,
input_data_fields.groundtruth_boxes: groundtruth_boxes,
input_data_fields.groundtruth_classes: groundtruth_classes,
detection_fields.detection_boxes: detection_boxes,
detection_fields.detection_scores: detection_scores,
detection_fields.detection_classes: detection_classes
}
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(eval_dict)
_, update_op = eval_metric_ops['DetectionBoxes_Precision/mAP']
......@@ -486,14 +507,20 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
detection_classes = tf.placeholder(tf.float32, shape=(batch_size, None))
num_det_boxes_per_image = tf.placeholder(tf.int32, shape=(None))
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(
image_id, groundtruth_boxes,
groundtruth_classes,
detection_boxes,
detection_scores,
detection_classes,
num_gt_boxes_per_image=num_gt_boxes_per_image,
num_det_boxes_per_image=num_det_boxes_per_image)
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key: image_id,
input_data_fields.groundtruth_boxes: groundtruth_boxes,
input_data_fields.groundtruth_classes: groundtruth_classes,
detection_fields.detection_boxes: detection_boxes,
detection_fields.detection_scores: detection_scores,
detection_fields.detection_classes: detection_classes,
'num_groundtruth_boxes_per_image': num_gt_boxes_per_image,
'num_det_boxes_per_image': num_det_boxes_per_image
}
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(eval_dict)
_, update_op = eval_metric_ops['DetectionBoxes_Precision/mAP']
......@@ -642,13 +669,19 @@ class CocoMaskEvaluationPyFuncTest(tf.test.TestCase):
detection_classes = tf.placeholder(tf.float32, shape=(None))
detection_masks = tf.placeholder(tf.uint8, shape=(None, None, None))
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(
image_id, groundtruth_boxes,
groundtruth_classes,
groundtruth_masks,
detection_scores,
detection_classes,
detection_masks)
input_data_fields = standard_fields.InputDataFields
detection_fields = standard_fields.DetectionResultFields
eval_dict = {
input_data_fields.key: image_id,
input_data_fields.groundtruth_boxes: groundtruth_boxes,
input_data_fields.groundtruth_classes: groundtruth_classes,
input_data_fields.groundtruth_instance_masks: groundtruth_masks,
detection_fields.detection_scores: detection_scores,
detection_fields.detection_classes: detection_classes,
detection_fields.detection_masks: detection_masks,
}
eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(eval_dict)
_, update_op = eval_metric_ops['DetectionMasks_Precision/mAP']
......
......@@ -234,6 +234,9 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
gt_keypoints_list = None
if fields.InputDataFields.groundtruth_keypoints in labels:
gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
gt_weights_list = None
if fields.InputDataFields.groundtruth_weights in labels:
gt_weights_list = labels[fields.InputDataFields.groundtruth_weights]
if fields.InputDataFields.groundtruth_is_crowd in labels:
gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
detection_model.provide_groundtruth(
......@@ -241,8 +244,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
groundtruth_classes_list=gt_classes_list,
groundtruth_masks_list=gt_masks_list,
groundtruth_keypoints_list=gt_keypoints_list,
groundtruth_weights_list=labels[
fields.InputDataFields.groundtruth_weights],
groundtruth_weights_list=gt_weights_list,
groundtruth_is_crowd_list=gt_is_crowd_list)
preprocessed_images = features[fields.InputDataFields.image]
......@@ -313,10 +315,16 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
# Optionally freeze some layers by setting their gradients to be zero.
trainable_variables = None
if train_config.freeze_variables:
trainable_variables = tf.contrib.framework.filter_variables(
tf.trainable_variables(),
exclude_patterns=train_config.freeze_variables)
include_variables = (
train_config.update_trainable_variables
if train_config.update_trainable_variables else None)
exclude_variables = (
train_config.freeze_variables
if train_config.freeze_variables else None)
trainable_variables = tf.contrib.framework.filter_variables(
tf.trainable_variables(),
include_patterns=include_variables,
exclude_patterns=exclude_variables)
clip_gradients_value = None
if train_config.gradient_clipping_by_norm > 0:
......@@ -377,14 +385,10 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
detection_and_groundtruth)
# Eval metrics on a single example.
eval_metrics = eval_config.metrics_set
if not eval_metrics:
eval_metrics = ['coco_detection_metrics']
eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
eval_metrics,
eval_config,
category_index.values(),
eval_dict,
include_metrics_per_category=eval_config.include_metrics_per_category)
eval_dict)
for loss_key, loss_tensor in iter(losses_dict.items()):
eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
for var in optimizer_summary_vars:
......
......@@ -178,6 +178,31 @@ class ModelLibTest(tf.test.TestCase):
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
self._assert_model_fn_for_train_eval(configs, 'train')
def test_model_fn_in_train_mode_freeze_all_variables(self):
"""Tests model_fn TRAIN mode with all variables frozen."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
configs['train_config'].freeze_variables.append('.*')
with self.assertRaisesRegexp(ValueError, 'No variables to optimize'):
self._assert_model_fn_for_train_eval(configs, 'train')
def test_model_fn_in_train_mode_freeze_all_included_variables(self):
"""Tests model_fn TRAIN mode with all included variables frozen."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
train_config = configs['train_config']
train_config.update_trainable_variables.append('FeatureExtractor')
train_config.freeze_variables.append('.*')
with self.assertRaisesRegexp(ValueError, 'No variables to optimize'):
self._assert_model_fn_for_train_eval(configs, 'train')
def test_model_fn_in_train_mode_freeze_box_predictor(self):
"""Tests model_fn TRAIN mode with FeatureExtractor variables frozen."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
train_config = configs['train_config']
train_config.update_trainable_variables.append('FeatureExtractor')
train_config.update_trainable_variables.append('BoxPredictor')
train_config.freeze_variables.append('FeatureExtractor')
self._assert_model_fn_for_train_eval(configs, 'train')
def test_model_fn_in_eval_mode(self):
"""Tests the model function in EVAL mode."""
configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
......
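A hedged sketch of the include/exclude semantics exercised by the tests above;
the pipeline path is a placeholder:

```python
from object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file('pipeline.config')
train_config = configs['train_config']

# Variables matching update_trainable_variables are candidates for training;
# any that also match freeze_variables are excluded. Here only BoxPredictor
# variables remain trainable.
train_config.update_trainable_variables.append('FeatureExtractor')
train_config.update_trainable_variables.append('BoxPredictor')
train_config.freeze_variables.append('FeatureExtractor')
```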
......@@ -18,6 +18,7 @@ import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.predictors import convolutional_box_predictor as box_predictor
from object_detection.protos import hyperparams_pb2
......@@ -44,18 +45,18 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
def test_get_boxes_for_five_aspect_ratios_per_location(self):
def graph_fn(image_features):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
use_dropout=True,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4
)
conv_box_predictor = (
box_predictor_builder.build_convolutional_box_predictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
use_dropout=True,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features], num_predictions_per_location=[5],
scope='BoxPredictor')
......@@ -73,18 +74,18 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
def test_get_boxes_for_one_aspect_ratio_per_location(self):
def graph_fn(image_features):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
use_dropout=True,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4
)
conv_box_predictor = (
box_predictor_builder.build_convolutional_box_predictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
use_dropout=True,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features], num_predictions_per_location=[1],
scope='BoxPredictor')
......@@ -104,18 +105,18 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
num_classes_without_background = 6
image_features = np.random.rand(4, 8, 8, 64).astype(np.float32)
def graph_fn(image_features):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
use_dropout=True,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4
)
conv_box_predictor = (
box_predictor_builder.build_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
use_dropout=True,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features],
num_predictions_per_location=[5],
......@@ -136,18 +137,18 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
def test_get_predictions_with_feature_maps_of_dynamic_shape(
self):
image_features = tf.placeholder(dtype=tf.float32, shape=[4, None, None, 64])
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
use_dropout=True,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4
)
conv_box_predictor = (
box_predictor_builder.build_convolutional_box_predictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
use_dropout=True,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features], num_predictions_per_location=[5],
scope='BoxPredictor')
......@@ -183,19 +184,19 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
def test_use_depthwise_convolution(self):
image_features = tf.placeholder(dtype=tf.float32, shape=[4, None, None, 64])
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4,
use_dropout=True,
use_depthwise=True
)
conv_box_predictor = (
box_predictor_builder.build_convolutional_box_predictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0,
max_depth=32,
num_layers_before_predictor=1,
dropout_keep_prob=0.8,
kernel_size=1,
box_code_size=4,
use_dropout=True,
use_depthwise=True))
box_predictions = conv_box_predictor.predict(
[image_features], num_predictions_per_location=[5],
scope='BoxPredictor')
......@@ -278,13 +279,14 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
def test_get_boxes_for_five_aspect_ratios_per_location(self):
def graph_fn(image_features):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features], num_predictions_per_location=[5],
scope='BoxPredictor')
......@@ -302,14 +304,15 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
def test_bias_predictions_to_background_with_sigmoid_score_conversion(self):
def graph_fn(image_features):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=True,
num_classes=2,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
class_prediction_bias_init=-4.6,
box_code_size=4)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=True,
num_classes=2,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
class_prediction_bias_init=-4.6,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features], num_predictions_per_location=[5],
scope='BoxPredictor')
......@@ -325,13 +328,14 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
num_classes_without_background = 6
def graph_fn(image_features):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features],
num_predictions_per_location=[5],
......@@ -354,13 +358,14 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features1, image_features2],
num_predictions_per_location=[5, 5],
......@@ -385,13 +390,14 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2, image_features3):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features1, image_features2, image_features3],
num_predictions_per_location=[5, 5, 5],
......@@ -416,13 +422,14 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
self):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features1, image_features2],
num_predictions_per_location=[5, 5],
......@@ -482,14 +489,15 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
self):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4,
apply_batch_norm=False)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4,
apply_batch_norm=False))
box_predictions = conv_box_predictor.predict(
[image_features1, image_features2],
num_predictions_per_location=[5, 5],
......@@ -540,14 +548,15 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
def test_no_batchnorm_params_when_batchnorm_is_not_configured(self):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_conv_arg_scope_no_batch_norm(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4,
apply_batch_norm=False)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_conv_arg_scope_no_batch_norm(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4,
apply_batch_norm=False))
box_predictions = conv_box_predictor.predict(
[image_features1, image_features2],
num_predictions_per_location=[5, 5],
......@@ -599,14 +608,15 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
self):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4,
share_prediction_tower=True)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4,
share_prediction_tower=True))
box_predictions = conv_box_predictor.predict(
[image_features1, image_features2],
num_predictions_per_location=[5, 5],
......@@ -653,15 +663,16 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
self):
num_classes_without_background = 6
def graph_fn(image_features1, image_features2):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4,
share_prediction_tower=True,
apply_batch_norm=False)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=num_classes_without_background,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=2,
box_code_size=4,
share_prediction_tower=True,
apply_batch_norm=False))
box_predictions = conv_box_predictor.predict(
[image_features1, image_features2],
num_predictions_per_location=[5, 5],
......@@ -698,18 +709,20 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
'ClassPredictor/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictor/biases')])
self.assertEqual(expected_variable_set, actual_variable_set)
def test_get_predictions_with_feature_maps_of_dynamic_shape(
self):
image_features = tf.placeholder(dtype=tf.float32, shape=[4, None, None, 64])
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4)
conv_box_predictor = (
box_predictor_builder.build_weight_shared_convolutional_box_predictor(
is_training=False,
num_classes=0,
conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32,
num_layers_before_predictor=1,
box_code_size=4))
box_predictions = conv_box_predictor.predict(
[image_features], num_predictions_per_location=[5],
scope='BoxPredictor')
......
......@@ -13,16 +13,25 @@
# limitations under the License.
# ==============================================================================
"""Mask R-CNN Box Head."""
"""Box Head.
Contains Box prediction head classes for different meta architectures.
All the box prediction heads have a predict function that receives the
`features` as the first argument and returns `box_encodings`.
"""
import tensorflow as tf
from object_detection.predictors.mask_rcnn_heads import mask_rcnn_head
from object_detection.predictors.heads import head
slim = tf.contrib.slim
class BoxHead(mask_rcnn_head.MaskRCNNHead):
"""Mask RCNN box prediction head."""
class MaskRCNNBoxHead(head.Head):
"""Box prediction head.
Please refer to Mask RCNN paper:
https://arxiv.org/abs/1703.06870
"""
def __init__(self,
is_training,
......@@ -51,7 +60,7 @@ class BoxHead(mask_rcnn_head.MaskRCNNHead):
share_box_across_classes: Whether to share boxes across classes rather
than use a different box for each class.
"""
super(BoxHead, self).__init__()
super(MaskRCNNBoxHead, self).__init__()
self._is_training = is_training
self._num_classes = num_classes
self._fc_hyperparams_fn = fc_hyperparams_fn
......@@ -60,20 +69,27 @@ class BoxHead(mask_rcnn_head.MaskRCNNHead):
self._box_code_size = box_code_size
self._share_box_across_classes = share_box_across_classes
def _predict(self, roi_pooled_features):
def predict(self, features, num_predictions_per_location=1):
"""Predicts boxes.
Args:
roi_pooled_features: A float tensor of shape [batch_size, height, width,
features: A float tensor of shape [batch_size, height, width,
channels] containing features for a batch of images.
num_predictions_per_location: Int containing number of predictions per
location.
Returns:
box_encodings: A float tensor of shape
[batch_size, 1, num_classes, code_size] representing the location of the
objects.
Raises:
ValueError: If num_predictions_per_location is not 1.
"""
if num_predictions_per_location != 1:
raise ValueError('Only num_predictions_per_location=1 is supported')
spatial_averaged_roi_pooled_features = tf.reduce_mean(
roi_pooled_features, [1, 2], keep_dims=True, name='AvgPool')
features, [1, 2], keep_dims=True, name='AvgPool')
flattened_roi_pooled_features = slim.flatten(
spatial_averaged_roi_pooled_features)
if self._use_dropout:
......@@ -94,3 +110,130 @@ class BoxHead(mask_rcnn_head.MaskRCNNHead):
box_encodings = tf.reshape(box_encodings,
[-1, 1, number_of_boxes, self._box_code_size])
return box_encodings
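As a shape walk-through of the predict path above (values taken from the prediction-size test later in this diff; illustrative only):

# roi_pooled_features: [64, 7, 7, 1024]
# reduce_mean over height/width (keep_dims=True): [64, 1, 1, 1024]
# slim.flatten:                                   [64, 1024]
# fully_connected(num_classes * code_size = 80):  [64, 80]
# reshape:                                        [64, 1, 20, 4]
# With share_box_across_classes=True, number_of_boxes is 1 and the
# output collapses to [64, 1, 1, 4].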
class ConvolutionalBoxHead(head.Head):
"""Convolutional box prediction head."""
def __init__(self,
is_training,
box_code_size,
kernel_size,
use_depthwise=False):
"""Constructor.
Args:
is_training: Indicates whether the BoxPredictor is in training mode.
box_code_size: Size of encoding for each box.
kernel_size: Size of final convolution kernel. If the
spatial resolution of the feature map is smaller than the kernel size,
then the kernel size is automatically set to be
min(feature_width, feature_height).
use_depthwise: Whether to use depthwise convolutions for prediction
steps. Default is False.
"""
super(ConvolutionalBoxHead, self).__init__()
self._is_training = is_training
self._box_code_size = box_code_size
self._kernel_size = kernel_size
self._use_depthwise = use_depthwise
def predict(self, features, num_predictions_per_location):
"""Predicts boxes.
Args:
features: A float tensor of shape [batch_size, height, width, channels]
containing image features.
num_predictions_per_location: Number of box predictions to be made per
spatial location. Int specifying number of boxes per location.
Returns:
box_encodings: A float tensor of shape
[batch_size, num_anchors, q, code_size] representing the location of
the objects, where q is 1 or the number of classes.
"""
net = features
if self._use_depthwise:
box_encodings = slim.separable_conv2d(
net, None, [self._kernel_size, self._kernel_size],
padding='SAME', depth_multiplier=1, stride=1,
rate=1, scope='BoxEncodingPredictor_depthwise')
box_encodings = slim.conv2d(
box_encodings,
num_predictions_per_location * self._box_code_size, [1, 1],
activation_fn=None,
normalizer_fn=None,
normalizer_params=None,
scope='BoxEncodingPredictor')
else:
box_encodings = slim.conv2d(
net, num_predictions_per_location * self._box_code_size,
[self._kernel_size, self._kernel_size],
activation_fn=None,
normalizer_fn=None,
normalizer_params=None,
scope='BoxEncodingPredictor')
batch_size = features.get_shape().as_list()[0]
if batch_size is None:
batch_size = tf.shape(features)[0]
box_encodings = tf.reshape(box_encodings,
[batch_size, -1, 1, self._box_code_size])
return box_encodings
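The -1 in the reshape above resolves to feature_height * feature_width * num_predictions_per_location. The prediction-size test later in this diff feeds a [64, 17, 19, 1024] feature map with one prediction per location and expects [64, 323, 1, 4], since 17 * 19 * 1 = 323. A minimal sketch of that bookkeeping:

def expected_num_anchors(feature_height, feature_width,
                         num_predictions_per_location):
  # Each spatial location contributes num_predictions_per_location anchors.
  return feature_height * feature_width * num_predictions_per_location

assert expected_num_anchors(17, 19, 1) == 323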
# TODO(alirezafathi): See if possible to unify Weight Shared with regular
# convolutional box head.
class WeightSharedConvolutionalBoxHead(head.Head):
"""Weight shared convolutional box prediction head.
This head allows sharing the same set of parameters (weights) when called more
than once on different feature maps.
"""
def __init__(self,
box_code_size,
kernel_size=3,
class_prediction_bias_init=0.0):
"""Constructor.
Args:
box_code_size: Size of encoding for each box.
kernel_size: Size of final convolution kernel.
class_prediction_bias_init: constant value to initialize bias of the last
conv2d layer before class prediction.
"""
super(WeightSharedConvolutionalBoxHead, self).__init__()
self._box_code_size = box_code_size
self._kernel_size = kernel_size
def predict(self, features, num_predictions_per_location):
"""Predicts boxes.
Args:
features: A float tensor of shape [batch_size, height, width, channels]
containing image features.
num_predictions_per_location: Number of box predictions to be made per
spatial location.
Returns:
box_encodings: A float tensor of shape
[batch_size, num_anchors, code_size] representing the location of
the objects.
"""
box_encodings_net = features
box_encodings = slim.conv2d(
box_encodings_net,
num_predictions_per_location * self._box_code_size,
[self._kernel_size, self._kernel_size],
activation_fn=None, stride=1, padding='SAME',
normalizer_fn=None,
scope='BoxPredictor')
batch_size = features.get_shape().as_list()[0]
if batch_size is None:
batch_size = tf.shape(features)[0]
box_encodings = tf.reshape(box_encodings,
[batch_size, -1, self._box_code_size])
return box_encodings
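Note that the sharing is not enforced inside the head itself; it happens when the caller invokes predict on several feature maps under one variable scope with reuse enabled, so the 'BoxPredictor' convolution variables are created once. A hedged sketch of that calling pattern (TF 1.x; scope handling in the real predictor may differ):

head = WeightSharedConvolutionalBoxHead(box_code_size=4)
with tf.variable_scope('WeightSharedConvolutionalBoxPredictor',
                       reuse=tf.AUTO_REUSE):
  # Both calls resolve to the same 'BoxPredictor' weights.
  encodings_level3 = head.predict(features_level3,
                                  num_predictions_per_location=5)
  encodings_level4 = head.predict(features_level4,
                                  num_predictions_per_location=5)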
......@@ -13,17 +13,17 @@
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.predictors.mask_rcnn_heads.box_head."""
"""Tests for object_detection.predictors.heads.box_head."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.predictors.mask_rcnn_heads import box_head
from object_detection.predictors.heads import box_head
from object_detection.protos import hyperparams_pb2
from object_detection.utils import test_case
class BoxHeadTest(test_case.TestCase):
class MaskRCNNBoxHeadTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(self,
op_type=hyperparams_pb2.Hyperparams.FC):
......@@ -44,7 +44,7 @@ class BoxHeadTest(test_case.TestCase):
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
box_prediction_head = box_head.BoxHead(
box_prediction_head = box_head.MaskRCNNBoxHead(
is_training=False,
num_classes=20,
fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
......@@ -55,10 +55,73 @@ class BoxHeadTest(test_case.TestCase):
roi_pooled_features = tf.random_uniform(
[64, 7, 7, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
prediction = box_prediction_head.predict(
roi_pooled_features=roi_pooled_features)
tf.logging.info(prediction.shape)
features=roi_pooled_features, num_predictions_per_location=1)
self.assertAllEqual([64, 1, 20, 4], prediction.get_shape().as_list())
class ConvolutionalBoxPredictorTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(
self, op_type=hyperparams_pb2.Hyperparams.CONV):
hyperparams = hyperparams_pb2.Hyperparams()
hyperparams_text_proto = """
activation: NONE
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
"""
text_format.Merge(hyperparams_text_proto, hyperparams)
hyperparams.op = op_type
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
box_prediction_head = box_head.ConvolutionalBoxHead(
is_training=True,
box_code_size=4,
kernel_size=3)
image_feature = tf.random_uniform(
[64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
box_encodings = box_prediction_head.predict(
features=image_feature,
num_predictions_per_location=1)
self.assertAllEqual([64, 323, 1, 4], box_encodings.get_shape().as_list())
class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(
self, op_type=hyperparams_pb2.Hyperparams.CONV):
hyperparams = hyperparams_pb2.Hyperparams()
hyperparams_text_proto = """
activation: NONE
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
"""
text_format.Merge(hyperparams_text_proto, hyperparams)
hyperparams.op = op_type
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
box_prediction_head = box_head.WeightSharedConvolutionalBoxHead(
box_code_size=4)
image_feature = tf.random_uniform(
[64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
box_encodings = box_prediction_head.predict(
features=image_feature,
num_predictions_per_location=1)
self.assertAllEqual([64, 323, 4], box_encodings.get_shape().as_list())
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class Head.
Contains Class prediction head classes for different meta architectures.
All the class prediction heads have a predict function that receives the
`features` as the first argument and returns class predictions with background.
"""
import tensorflow as tf
from object_detection.predictors.heads import head
slim = tf.contrib.slim
class MaskRCNNClassHead(head.Head):
"""Mask RCNN class prediction head.
Please refer to Mask RCNN paper:
https://arxiv.org/abs/1703.06870
"""
def __init__(self, is_training, num_classes, fc_hyperparams_fn,
use_dropout, dropout_keep_prob):
"""Constructor.
Args:
is_training: Indicates whether the BoxPredictor is in training mode.
num_classes: number of classes. Note that num_classes *does not*
include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}).
fc_hyperparams_fn: A function to generate tf-slim arg_scope with
hyperparameters for fully connected ops.
use_dropout: Option to use dropout or not. Note that a single dropout
op is applied here prior to both box and class predictions, which stands
in contrast to the ConvolutionalBoxPredictor below.
dropout_keep_prob: Keep probability for dropout.
This is only used if use_dropout is True.
"""
super(MaskRCNNClassHead, self).__init__()
self._is_training = is_training
self._num_classes = num_classes
self._fc_hyperparams_fn = fc_hyperparams_fn
self._use_dropout = use_dropout
self._dropout_keep_prob = dropout_keep_prob
def predict(self, features, num_predictions_per_location=1):
"""Predicts boxes and class scores.
Args:
features: A float tensor of shape [batch_size, height, width, channels]
containing features for a batch of images.
num_predictions_per_location: Int containing number of predictions per
location.
Returns:
class_predictions_with_background: A float tensor of shape
[batch_size, 1, num_classes + 1] representing the class predictions for
the proposals.
Raises:
ValueError: If num_predictions_per_location is not 1.
"""
if num_predictions_per_location != 1:
raise ValueError('Only num_predictions_per_location=1 is supported')
spatial_averaged_roi_pooled_features = tf.reduce_mean(
features, [1, 2], keep_dims=True, name='AvgPool')
flattened_roi_pooled_features = slim.flatten(
spatial_averaged_roi_pooled_features)
if self._use_dropout:
flattened_roi_pooled_features = slim.dropout(
flattened_roi_pooled_features,
keep_prob=self._dropout_keep_prob,
is_training=self._is_training)
with slim.arg_scope(self._fc_hyperparams_fn()):
class_predictions_with_background = slim.fully_connected(
flattened_roi_pooled_features,
self._num_classes + 1,
activation_fn=None,
scope='ClassPredictor')
class_predictions_with_background = tf.reshape(
class_predictions_with_background, [-1, 1, self._num_classes + 1])
return class_predictions_with_background
class ConvolutionalClassHead(head.Head):
"""Convolutional class prediction head."""
def __init__(self,
is_training,
num_classes,
use_dropout,
dropout_keep_prob,
kernel_size,
apply_sigmoid_to_scores=False,
class_prediction_bias_init=0.0,
use_depthwise=False):
"""Constructor.
Args:
is_training: Indicates whether the BoxPredictor is in training mode.
num_classes: Number of classes.
use_dropout: Option to use dropout or not. Note that a single dropout
op is applied here prior to both box and class predictions, which stands
in contrast to the ConvolutionalBoxPredictor below.
dropout_keep_prob: Keep probability for dropout.
This is only used if use_dropout is True.
kernel_size: Size of final convolution kernel. If the
spatial resolution of the feature map is smaller than the kernel size,
then the kernel size is automatically set to be
min(feature_width, feature_height).
apply_sigmoid_to_scores: if True, apply the sigmoid on the output
class_predictions.
class_prediction_bias_init: constant value to initialize bias of the last
conv2d layer before class prediction.
use_depthwise: Whether to use depthwise convolutions for prediction
steps. Default is False.
"""
super(ConvolutionalClassHead, self).__init__()
self._is_training = is_training
self._num_classes = num_classes
self._use_dropout = use_dropout
self._dropout_keep_prob = dropout_keep_prob
self._kernel_size = kernel_size
self._apply_sigmoid_to_scores = apply_sigmoid_to_scores
self._class_prediction_bias_init = class_prediction_bias_init
self._use_depthwise = use_depthwise
def predict(self, features, num_predictions_per_location):
"""Predicts boxes.
Args:
features: A float tensor of shape [batch_size, height, width, channels]
containing image features.
num_predictions_per_location: Number of box predictions to be made per
spatial location.
Returns:
class_predictions_with_background: A float tensor of shape
[batch_size, num_anchors, num_classes + 1] representing the class
predictions for the proposals.
"""
net = features
# Add a slot for the background class.
num_class_slots = self._num_classes + 1
if self._use_dropout:
net = slim.dropout(net, keep_prob=self._dropout_keep_prob)
if self._use_depthwise:
class_predictions_with_background = slim.separable_conv2d(
net, None, [self._kernel_size, self._kernel_size],
padding='SAME', depth_multiplier=1, stride=1,
rate=1, scope='ClassPredictor_depthwise')
class_predictions_with_background = slim.conv2d(
class_predictions_with_background,
num_predictions_per_location * num_class_slots, [1, 1],
activation_fn=None,
normalizer_fn=None,
normalizer_params=None,
scope='ClassPredictor')
else:
class_predictions_with_background = slim.conv2d(
net,
num_predictions_per_location * num_class_slots,
[self._kernel_size, self._kernel_size],
activation_fn=None,
normalizer_fn=None,
normalizer_params=None,
scope='ClassPredictor',
biases_initializer=tf.constant_initializer(
self._class_prediction_bias_init))
if self._apply_sigmoid_to_scores:
class_predictions_with_background = tf.sigmoid(
class_predictions_with_background)
batch_size = features.get_shape().as_list()[0]
if batch_size is None:
batch_size = tf.shape(features)[0]
class_predictions_with_background = tf.reshape(
class_predictions_with_background, [batch_size, -1, num_class_slots])
return class_predictions_with_background
# TODO(alirezafathi): See if possible to unify Weight Shared with regular
# convolutional class head.
class WeightSharedConvolutionalClassHead(head.Head):
"""Weight shared convolutional class prediction head.
This head allows sharing the same set of parameters (weights) when called more
than once on different feature maps.
"""
def __init__(self,
num_classes,
kernel_size=3,
class_prediction_bias_init=0.0,
use_dropout=False,
dropout_keep_prob=0.8):
"""Constructor.
Args:
num_classes: number of classes. Note that num_classes *does not*
include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}).
kernel_size: Size of final convolution kernel.
class_prediction_bias_init: constant value to initialize bias of the last
conv2d layer before class prediction.
use_dropout: Whether to apply dropout to class prediction head.
dropout_keep_prob: Probability of keeping activations.
"""
super(WeightSharedConvolutionalClassHead, self).__init__()
self._num_classes = num_classes
self._kernel_size = kernel_size
self._class_prediction_bias_init = class_prediction_bias_init
self._use_dropout = use_dropout
self._dropout_keep_prob = dropout_keep_prob
def predict(self, features, num_predictions_per_location):
"""Predicts boxes.
Args:
features: A float tensor of shape [batch_size, height, width, channels]
containing image features.
num_predictions_per_location: Number of box predictions to be made per
spatial location.
Returns:
class_predictions_with_background: A tensor of shape
[batch_size, num_anchors, num_classes + 1] representing the class
predictions for the proposals.
"""
class_predictions_net = features
# Add a slot for the background class.
num_class_slots = self._num_classes + 1
if self._use_dropout:
class_predictions_net = slim.dropout(
class_predictions_net, keep_prob=self._dropout_keep_prob)
class_predictions_with_background = slim.conv2d(
class_predictions_net,
num_predictions_per_location * num_class_slots,
[self._kernel_size, self._kernel_size],
activation_fn=None, stride=1, padding='SAME',
normalizer_fn=None,
biases_initializer=tf.constant_initializer(
self._class_prediction_bias_init),
scope='ClassPredictor')
batch_size = features.get_shape().as_list()[0]
if batch_size is None:
batch_size = tf.shape(features)[0]
class_predictions_with_background = tf.reshape(
class_predictions_with_background, [batch_size, -1, num_class_slots])
return class_predictions_with_background
......@@ -13,17 +13,17 @@
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.predictors.mask_rcnn_heads.class_head."""
"""Tests for object_detection.predictors.heads.class_head."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.predictors.mask_rcnn_heads import class_head
from object_detection.predictors.heads import class_head
from object_detection.protos import hyperparams_pb2
from object_detection.utils import test_case
class ClassHeadTest(test_case.TestCase):
class MaskRCNNClassHeadTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(self,
op_type=hyperparams_pb2.Hyperparams.FC):
......@@ -44,7 +44,7 @@ class ClassHeadTest(test_case.TestCase):
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
class_prediction_head = class_head.ClassHead(
class_prediction_head = class_head.MaskRCNNClassHead(
is_training=False,
num_classes=20,
fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
......@@ -53,10 +53,76 @@ class ClassHeadTest(test_case.TestCase):
roi_pooled_features = tf.random_uniform(
[64, 7, 7, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
prediction = class_prediction_head.predict(
roi_pooled_features=roi_pooled_features)
tf.logging.info(prediction.shape)
features=roi_pooled_features, num_predictions_per_location=1)
self.assertAllEqual([64, 1, 21], prediction.get_shape().as_list())
class ConvolutionalClassPredictorTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(
self, op_type=hyperparams_pb2.Hyperparams.CONV):
hyperparams = hyperparams_pb2.Hyperparams()
hyperparams_text_proto = """
activation: NONE
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
"""
text_format.Merge(hyperparams_text_proto, hyperparams)
hyperparams.op = op_type
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
class_prediction_head = class_head.ConvolutionalClassHead(
is_training=True,
num_classes=20,
use_dropout=True,
dropout_keep_prob=0.5,
kernel_size=3)
image_feature = tf.random_uniform(
[64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
class_predictions = class_prediction_head.predict(
features=image_feature,
num_predictions_per_location=1)
self.assertAllEqual([64, 323, 21],
class_predictions.get_shape().as_list())
class WeightSharedConvolutionalClassPredictorTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(
self, op_type=hyperparams_pb2.Hyperparams.CONV):
hyperparams = hyperparams_pb2.Hyperparams()
hyperparams_text_proto = """
activation: NONE
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
"""
text_format.Merge(hyperparams_text_proto, hyperparams)
hyperparams.op = op_type
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
class_prediction_head = (
class_head.WeightSharedConvolutionalClassHead(num_classes=20))
image_feature = tf.random_uniform(
[64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
class_predictions = class_prediction_head.predict(
features=image_feature,
num_predictions_per_location=1)
self.assertAllEqual([64, 323, 21], class_predictions.get_shape().as_list())
if __name__ == '__main__':
tf.test.main()
......@@ -13,32 +13,47 @@
# limitations under the License.
# ==============================================================================
"""Base Mask RCNN head class."""
"""Base head class.
All the different kinds of prediction heads in different models will inherit
from this class. What is in common between all head classes is that they have a
`predict` function that receives `features` as its first argument.
How to add a new prediction head to an existing meta architecture?
For example, how can we add a `3d shape` prediction head to Mask RCNN?
We have to take the following steps to add a new prediction head to an
existing meta arch:
(a) Add a class for predicting the head. This class should inherit from the
`Head` class below and have a `predict` function that receives the features
and predicts the output. The output is always a tf.float32 tensor.
(b) Add the head to the meta architecture. For example in case of Mask RCNN,
go to box_predictor_builder and put in the logic for adding the new head to the
Mask RCNN box predictor.
(c) Add the logic for computing the loss for the new head.
(d) Add the necessary metrics for the new head.
(e) (optional) Add visualization for the new head.
"""
from abc import abstractmethod
class MaskRCNNHead(object):
class Head(object):
"""Mask RCNN head base class."""
def __init__(self):
"""Constructor."""
pass
def predict(self, roi_pooled_features):
@abstractmethod
def predict(self, features, num_predictions_per_location):
"""Returns the head's predictions.
Args:
roi_pooled_features: A float tensor of shape
[batch_size, height, width, channels] containing ROI pooled features
from a batch of boxes.
"""
return self._predict(roi_pooled_features)
features: A float tensor of features.
num_predictions_per_location: Int containing number of predictions per
location.
@abstractmethod
def _predict(self, roi_pooled_features):
"""The abstract internal prediction function that needs to be overloaded.
Args:
roi_pooled_features: A float tensor of shape
[batch_size, height, width, channels] containing ROI pooled features
from a batch of boxes.
Returns:
A tf.float32 tensor.
"""
pass
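To make step (a) concrete, a new head only has to subclass Head and implement predict. A minimal, hypothetical `3d shape` head might look like the sketch below (the 1x1 convolution and output layout are assumptions for illustration, not part of this commit; the tf and slim imports used by the concrete head files are assumed):

class Hypothetical3DShapeHead(Head):
  """Toy custom head predicting one shape code per anchor."""

  def __init__(self, shape_code_size):
    super(Hypothetical3DShapeHead, self).__init__()
    self._shape_code_size = shape_code_size

  def predict(self, features, num_predictions_per_location):
    # A 1x1 convolution emits shape codes for every spatial location.
    shape_encodings = slim.conv2d(
        features,
        num_predictions_per_location * self._shape_code_size, [1, 1],
        activation_fn=None, scope='ShapePredictor')
    batch_size = tf.shape(features)[0]
    return tf.reshape(
        shape_encodings, [batch_size, -1, self._shape_code_size])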
......@@ -13,15 +13,27 @@
# limitations under the License.
# ==============================================================================
"""Mask R-CNN Keypoint Head."""
"""Keypoint Head.
Contains Keypoint prediction head classes for different meta architectures.
All the keypoint prediction heads have a predict function that receives the
`features` as the first argument and returns `keypoint_predictions`.
Keypoints can be used to represent human body joint locations, as in the
Mask RCNN paper, or to represent the locations of different parts of
objects.
"""
import tensorflow as tf
from object_detection.predictors.mask_rcnn_heads import mask_rcnn_head
from object_detection.predictors.heads import head
slim = tf.contrib.slim
class KeypointHead(mask_rcnn_head.MaskRCNNHead):
"""Mask RCNN keypoint prediction head."""
class MaskRCNNKeypointHead(head.Head):
"""Mask RCNN keypoint prediction head.
Please refer to Mask RCNN paper:
https://arxiv.org/abs/1703.06870
"""
def __init__(self,
num_keypoints=17,
......@@ -48,7 +60,7 @@ class KeypointHead(mask_rcnn_head.MaskRCNNHead):
based on the number of object classes and the number of channels in the
image features.
"""
super(KeypointHead, self).__init__()
super(MaskRCNNKeypointHead, self).__init__()
self._num_keypoints = num_keypoints
self._conv_hyperparams_fn = conv_hyperparams_fn
self._keypoint_heatmap_height = keypoint_heatmap_height
......@@ -57,20 +69,27 @@ class KeypointHead(mask_rcnn_head.MaskRCNNHead):
keypoint_prediction_num_conv_layers)
self._keypoint_prediction_conv_depth = keypoint_prediction_conv_depth
def _predict(self, roi_pooled_features):
def predict(self, features, num_predictions_per_location=1):
"""Performs keypoint prediction.
Args:
roi_pooled_features: A float tensor of shape [batch_size, height, width,
features: A float tensor of shape [batch_size, height, width,
channels] containing features for a batch of images.
num_predictions_per_location: Int containing number of predictions per
location.
Returns:
keypoint_heatmaps: A float tensor of shape
[batch_size, 1, num_keypoints, heatmap_height, heatmap_width].
Raises:
ValueError: If num_predictions_per_location is not 1.
"""
if num_predictions_per_location != 1:
raise ValueError('Only num_predictions_per_location=1 is supported')
with slim.arg_scope(self._conv_hyperparams_fn()):
net = slim.conv2d(
roi_pooled_features,
features,
self._keypoint_prediction_conv_depth, [3, 3],
scope='conv_1')
for i in range(1, self._keypoint_prediction_num_conv_layers):
......
......@@ -13,17 +13,17 @@
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.predictors.mask_rcnn_heads.keypoint_head."""
"""Tests for object_detection.predictors.heads.keypoint_head."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.predictors.mask_rcnn_heads import keypoint_head
from object_detection.predictors.heads import keypoint_head
from object_detection.protos import hyperparams_pb2
from object_detection.utils import test_case
class KeypointHeadTest(test_case.TestCase):
class MaskRCNNKeypointHeadTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(self,
op_type=hyperparams_pb2.Hyperparams.FC):
......@@ -44,13 +44,12 @@ class KeypointHeadTest(test_case.TestCase):
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
keypoint_prediction_head = keypoint_head.KeypointHead(
keypoint_prediction_head = keypoint_head.MaskRCNNKeypointHead(
conv_hyperparams_fn=self._build_arg_scope_with_hyperparams())
roi_pooled_features = tf.random_uniform(
[64, 14, 14, 1024], minval=-2.0, maxval=2.0, dtype=tf.float32)
prediction = keypoint_prediction_head.predict(
roi_pooled_features=roi_pooled_features)
tf.logging.info(prediction.shape)
features=roi_pooled_features, num_predictions_per_location=1)
self.assertAllEqual([64, 1, 17, 56, 56], prediction.get_shape().as_list())
......
......@@ -13,17 +13,26 @@
# limitations under the License.
# ==============================================================================
"""Mask R-CNN Mask Head."""
"""Mask Head.
Contains Mask prediction head classes for different meta architectures.
All the mask prediction heads have a predict function that receives the
`features` as the first argument and returns `mask_predictions`.
"""
import math
import tensorflow as tf
from object_detection.predictors.mask_rcnn_heads import mask_rcnn_head
from object_detection.predictors.heads import head
slim = tf.contrib.slim
class MaskHead(mask_rcnn_head.MaskRCNNHead):
"""Mask RCNN mask prediction head."""
class MaskRCNNMaskHead(head.Head):
"""Mask RCNN mask prediction head.
Please refer to Mask RCNN paper:
https://arxiv.org/abs/1703.06870
"""
def __init__(self,
num_classes,
......@@ -57,7 +66,7 @@ class MaskHead(mask_rcnn_head.MaskRCNNHead):
Raises:
ValueError: conv_hyperparams_fn is None.
"""
super(MaskHead, self).__init__()
super(MaskRCNNMaskHead, self).__init__()
self._num_classes = num_classes
self._conv_hyperparams_fn = conv_hyperparams_fn
self._mask_height = mask_height
......@@ -102,25 +111,32 @@ class MaskHead(mask_rcnn_head.MaskRCNNHead):
total_weight)
return int(math.pow(2.0, num_conv_channels_log))
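The helper's body is mostly elided in this hunk; it computes a rounded weighted average of log2(feature channels) and log2(num classes) and maps the result back to a power of two (the 2.0/3.0 weights below are assumed from the upstream defaults). A numeric sketch for the 1024-channel, 20-class configuration used in the tests:

import math

def assumed_conv_depth(num_feature_channels, num_classes,
                       class_weight=3.0, feature_weight=2.0):
  # Weighted average in log2 space, rounded, then back to a power of two.
  weighted_log = (math.log(num_feature_channels, 2.0) * feature_weight +
                  math.log(num_classes, 2.0) * class_weight)
  return int(math.pow(
      2.0, round(weighted_log / (feature_weight + class_weight))))

print(assumed_conv_depth(1024, 20))  # 128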
def _predict(self, roi_pooled_features):
def predict(self, features, num_predictions_per_location=1):
"""Performs mask prediction.
Args:
roi_pooled_features: A float tensor of shape [batch_size, height, width,
channels] containing features for a batch of images.
features: A float tensor of shape [batch_size, height, width, channels]
containing features for a batch of images.
num_predictions_per_location: Int containing number of predictions per
location.
Returns:
instance_masks: A float tensor of shape
[batch_size, 1, num_classes, mask_height, mask_width].
Raises:
ValueError: If num_predictions_per_location is not 1.
"""
if num_predictions_per_location != 1:
raise ValueError('Only num_predictions_per_location=1 is supported')
num_conv_channels = self._mask_prediction_conv_depth
if num_conv_channels == 0:
num_feature_channels = roi_pooled_features.get_shape().as_list()[3]
num_feature_channels = features.get_shape().as_list()[3]
num_conv_channels = self._get_mask_predictor_conv_depth(
num_feature_channels, self._num_classes)
with slim.arg_scope(self._conv_hyperparams_fn()):
upsampled_features = tf.image.resize_bilinear(
roi_pooled_features, [self._mask_height, self._mask_width],
features, [self._mask_height, self._mask_width],
align_corners=True)
for _ in range(self._mask_prediction_num_conv_layers - 1):
upsampled_features = slim.conv2d(
......@@ -137,3 +153,182 @@ class MaskHead(mask_rcnn_head.MaskRCNNHead):
tf.transpose(mask_predictions, perm=[0, 3, 1, 2]),
axis=1,
name='MaskPredictor')
class ConvolutionalMaskHead(head.Head):
"""Convolutional class prediction head."""
def __init__(self,
is_training,
num_classes,
use_dropout,
dropout_keep_prob,
kernel_size,
use_depthwise=False,
mask_height=7,
mask_width=7,
masks_are_class_agnostic=False):
"""Constructor.
Args:
is_training: Indicates whether the BoxPredictor is in training mode.
num_classes: Number of classes.
use_dropout: Option to use dropout or not. Note that a single dropout
op is applied here prior to both box and class predictions, which stands
in contrast to the ConvolutionalBoxPredictor below.
dropout_keep_prob: Keep probability for dropout.
This is only used if use_dropout is True.
kernel_size: Size of final convolution kernel. If the
spatial resolution of the feature map is smaller than the kernel size,
then the kernel size is automatically set to be
min(feature_width, feature_height).
use_depthwise: Whether to use depthwise convolutions for prediction
steps. Default is False.
mask_height: Desired output mask height. The default value is 7.
mask_width: Desired output mask width. The default value is 7.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
"""
super(ConvolutionalMaskHead, self).__init__()
self._is_training = is_training
self._num_classes = num_classes
self._use_dropout = use_dropout
self._dropout_keep_prob = dropout_keep_prob
self._kernel_size = kernel_size
self._use_depthwise = use_depthwise
self._mask_height = mask_height
self._mask_width = mask_width
self._masks_are_class_agnostic = masks_are_class_agnostic
def predict(self, features, num_predictions_per_location):
"""Predicts boxes.
Args:
features: A float tensor of shape [batch_size, height, width, channels]
containing image features.
num_predictions_per_location: Number of box predictions to be made per
spatial location.
Returns:
mask_predictions: A float tensor of shape
[batch_size, num_anchors, num_masks, mask_height, mask_width]
representing the mask predictions for the proposals.
"""
image_feature = features
# Class-agnostic heads predict a single shared mask; otherwise one per class.
if self._masks_are_class_agnostic:
num_masks = 1
else:
num_masks = self._num_classes
num_mask_channels = num_masks * self._mask_height * self._mask_width
net = image_feature
if self._use_dropout:
net = slim.dropout(net, keep_prob=self._dropout_keep_prob)
if self._use_depthwise:
mask_predictions = slim.separable_conv2d(
net, None, [self._kernel_size, self._kernel_size],
padding='SAME', depth_multiplier=1, stride=1,
rate=1, scope='MaskPredictor_depthwise')
mask_predictions = slim.conv2d(
mask_predictions,
num_predictions_per_location * num_mask_channels,
[1, 1],
activation_fn=None,
normalizer_fn=None,
normalizer_params=None,
scope='MaskPredictor')
else:
mask_predictions = slim.conv2d(
net,
num_predictions_per_location * num_mask_channels,
[self._kernel_size, self._kernel_size],
activation_fn=None,
normalizer_fn=None,
normalizer_params=None,
scope='MaskPredictor')
batch_size = features.get_shape().as_list()[0]
if batch_size is None:
batch_size = tf.shape(features)[0]
mask_predictions = tf.reshape(
mask_predictions,
[batch_size, -1, num_masks, self._mask_height, self._mask_width])
return mask_predictions
# TODO(alirezafathi): See if possible to unify Weight Shared with regular
# convolutional mask head.
class WeightSharedConvolutionalMaskHead(head.Head):
"""Weight shared convolutional mask prediction head."""
def __init__(self,
num_classes,
kernel_size=3,
use_dropout=False,
dropout_keep_prob=0.8,
mask_height=7,
mask_width=7,
masks_are_class_agnostic=False):
"""Constructor.
Args:
num_classes: number of classes. Note that num_classes *does not*
include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}).
kernel_size: Size of final convolution kernel.
use_dropout: Whether to apply dropout to class prediction head.
dropout_keep_prob: Probability of keeping activations.
mask_height: Desired output mask height. The default value is 7.
mask_width: Desired output mask width. The default value is 7.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
"""
super(WeightSharedConvolutionalMaskHead, self).__init__()
self._num_classes = num_classes
self._kernel_size = kernel_size
self._use_dropout = use_dropout
self._dropout_keep_prob = dropout_keep_prob
self._mask_height = mask_height
self._mask_width = mask_width
self._masks_are_class_agnostic = masks_are_class_agnostic
def predict(self, features, num_predictions_per_location):
"""Predicts boxes.
Args:
features: A float tensor of shape [batch_size, height, width, channels]
containing image features.
num_predictions_per_location: Number of box predictions to be made per
spatial location.
Returns:
mask_predictions: A tensor of shape
[batch_size, num_anchors, num_classes, mask_height, mask_width]
representing the mask predictions for the proposals.
"""
mask_predictions_net = features
if self._masks_are_class_agnostic:
num_masks = 1
else:
num_masks = self._num_classes
num_mask_channels = num_masks * self._mask_height * self._mask_width
if self._use_dropout:
mask_predictions_net = slim.dropout(
mask_predictions_net, keep_prob=self._dropout_keep_prob)
mask_predictions = slim.conv2d(
mask_predictions_net,
num_predictions_per_location * num_mask_channels,
[self._kernel_size, self._kernel_size],
activation_fn=None, stride=1, padding='SAME',
normalizer_fn=None,
scope='MaskPredictor')
batch_size = features.get_shape().as_list()[0]
if batch_size is None:
batch_size = tf.shape(features)[0]
mask_predictions = tf.reshape(
mask_predictions,
[batch_size, -1, num_masks, self._mask_height, self._mask_width])
return mask_predictions
......@@ -13,17 +13,17 @@
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.predictors.mask_rcnn_heads.mask_head."""
"""Tests for object_detection.predictors.heads.mask_head."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.predictors.mask_rcnn_heads import mask_head
from object_detection.predictors.heads import mask_head
from object_detection.protos import hyperparams_pb2
from object_detection.utils import test_case
class MaskHeadTest(test_case.TestCase):
class MaskRCNNMaskHeadTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(self,
op_type=hyperparams_pb2.Hyperparams.FC):
......@@ -44,7 +44,7 @@ class MaskHeadTest(test_case.TestCase):
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
mask_prediction_head = mask_head.MaskHead(
mask_prediction_head = mask_head.MaskRCNNMaskHead(
num_classes=20,
conv_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
mask_height=14,
......@@ -55,10 +55,115 @@ class MaskHeadTest(test_case.TestCase):
roi_pooled_features = tf.random_uniform(
[64, 7, 7, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
prediction = mask_prediction_head.predict(
roi_pooled_features=roi_pooled_features)
tf.logging.info(prediction.shape)
features=roi_pooled_features, num_predictions_per_location=1)
self.assertAllEqual([64, 1, 20, 14, 14], prediction.get_shape().as_list())
class ConvolutionalMaskPredictorTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(
self, op_type=hyperparams_pb2.Hyperparams.CONV):
hyperparams = hyperparams_pb2.Hyperparams()
hyperparams_text_proto = """
activation: NONE
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
"""
text_format.Merge(hyperparams_text_proto, hyperparams)
hyperparams.op = op_type
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
mask_prediction_head = mask_head.ConvolutionalMaskHead(
is_training=True,
num_classes=20,
use_dropout=True,
dropout_keep_prob=0.5,
kernel_size=3,
mask_height=7,
mask_width=7)
image_feature = tf.random_uniform(
[64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
mask_predictions = mask_prediction_head.predict(
features=image_feature,
num_predictions_per_location=1)
self.assertAllEqual([64, 323, 20, 7, 7],
mask_predictions.get_shape().as_list())
def test_class_agnostic_prediction_size(self):
mask_prediction_head = mask_head.ConvolutionalMaskHead(
is_training=True,
num_classes=20,
use_dropout=True,
dropout_keep_prob=0.5,
kernel_size=3,
mask_height=7,
mask_width=7,
masks_are_class_agnostic=True)
image_feature = tf.random_uniform(
[64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
mask_predictions = mask_prediction_head.predict(
features=image_feature,
num_predictions_per_location=1)
self.assertAllEqual([64, 323, 1, 7, 7],
mask_predictions.get_shape().as_list())
class WeightSharedConvolutionalMaskPredictorTest(test_case.TestCase):
def _build_arg_scope_with_hyperparams(
self, op_type=hyperparams_pb2.Hyperparams.CONV):
hyperparams = hyperparams_pb2.Hyperparams()
hyperparams_text_proto = """
activation: NONE
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
"""
text_format.Merge(hyperparams_text_proto, hyperparams)
hyperparams.op = op_type
return hyperparams_builder.build(hyperparams, is_training=True)
def test_prediction_size(self):
mask_prediction_head = (
mask_head.WeightSharedConvolutionalMaskHead(
num_classes=20,
mask_height=7,
mask_width=7))
image_feature = tf.random_uniform(
[64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
mask_predictions = mask_prediction_head.predict(
features=image_feature,
num_predictions_per_location=1)
self.assertAllEqual([64, 323, 20, 7, 7],
mask_predictions.get_shape().as_list())
def test_class_agnostic_prediction_size(self):
mask_prediction_head = (
mask_head.WeightSharedConvolutionalMaskHead(
num_classes=20,
mask_height=7,
mask_width=7,
masks_are_class_agnostic=True))
image_feature = tf.random_uniform(
[64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32)
mask_predictions = mask_prediction_head.predict(
features=image_feature,
num_predictions_per_location=1)
self.assertAllEqual([64, 323, 1, 7, 7],
mask_predictions.get_shape().as_list())
if __name__ == '__main__':
tf.test.main()
......@@ -126,15 +126,18 @@ class MaskRCNNBoxPredictor(box_predictor.BoxPredictor):
if prediction_stage == 2:
predictions_dict[BOX_ENCODINGS] = self._box_prediction_head.predict(
roi_pooled_features=image_feature)
features=image_feature,
num_predictions_per_location=num_predictions_per_location[0])
predictions_dict[CLASS_PREDICTIONS_WITH_BACKGROUND] = (
self._class_prediction_head.predict(roi_pooled_features=image_feature)
)
self._class_prediction_head.predict(
features=image_feature,
num_predictions_per_location=num_predictions_per_location[0]))
elif prediction_stage == 3:
for prediction_head in self.get_third_stage_prediction_heads():
head_object = self._third_stage_heads[prediction_head]
predictions_dict[prediction_head] = head_object.predict(
roi_pooled_features=image_feature)
features=image_feature,
num_predictions_per_location=num_predictions_per_location[0])
else:
raise ValueError('prediction_stage should be either 2 or 3.')
......
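The dispatch above mirrors the two prediction passes of Mask R-CNN: prediction_stage=2 produces box encodings and class scores from proposal features, while prediction_stage=3 runs only the registered third-stage heads (for example masks) on features of the refined boxes. A hedged sketch of the calling convention (keyword wiring follows this diff; the exact predict signature is defined in the base BoxPredictor class):

# Stage 2: box refinement and classification of proposals.
stage2_outputs = mask_box_predictor.predict(
    [proposal_features], num_predictions_per_location=[1],
    scope='BoxPredictor', prediction_stage=2)

# Stage 3: third-stage heads (e.g. MASK_PREDICTIONS) on refined boxes.
stage3_outputs = mask_box_predictor.predict(
    [refined_box_features], num_predictions_per_location=[1],
    scope='BoxPredictor', prediction_stage=3)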
......@@ -18,11 +18,9 @@ import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.predictors import mask_rcnn_box_predictor as box_predictor
from object_detection.predictors.mask_rcnn_heads import box_head
from object_detection.predictors.mask_rcnn_heads import class_head
from object_detection.predictors.mask_rcnn_heads import mask_head
from object_detection.protos import hyperparams_pb2
from object_detection.utils import test_case
......@@ -47,45 +45,9 @@ class MaskRCNNBoxPredictorTest(test_case.TestCase):
hyperparams.op = op_type
return hyperparams_builder.build(hyperparams, is_training=True)
def _box_predictor_builder(self,
is_training,
num_classes,
fc_hyperparams_fn,
use_dropout,
dropout_keep_prob,
box_code_size,
share_box_across_classes=False,
conv_hyperparams_fn=None,
predict_instance_masks=False):
box_prediction_head = box_head.BoxHead(
is_training=is_training,
num_classes=num_classes,
fc_hyperparams_fn=fc_hyperparams_fn,
use_dropout=use_dropout,
dropout_keep_prob=dropout_keep_prob,
box_code_size=box_code_size,
share_box_across_classes=share_box_across_classes)
class_prediction_head = class_head.ClassHead(
is_training=is_training,
num_classes=num_classes,
fc_hyperparams_fn=fc_hyperparams_fn,
use_dropout=use_dropout,
dropout_keep_prob=dropout_keep_prob)
third_stage_heads = {}
if predict_instance_masks:
third_stage_heads[box_predictor.MASK_PREDICTIONS] = mask_head.MaskHead(
num_classes=num_classes,
conv_hyperparams_fn=conv_hyperparams_fn)
return box_predictor.MaskRCNNBoxPredictor(
is_training=is_training,
num_classes=num_classes,
box_prediction_head=box_prediction_head,
class_prediction_head=class_prediction_head,
third_stage_heads=third_stage_heads)
def test_get_boxes_with_five_classes(self):
def graph_fn(image_features):
mask_box_predictor = self._box_predictor_builder(
mask_box_predictor = box_predictor_builder.build_mask_rcnn_box_predictor(
is_training=False,
num_classes=5,
fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
......@@ -109,7 +71,7 @@ class MaskRCNNBoxPredictorTest(test_case.TestCase):
def test_get_boxes_with_five_classes_share_box_across_classes(self):
def graph_fn(image_features):
mask_box_predictor = self._box_predictor_builder(
mask_box_predictor = box_predictor_builder.build_mask_rcnn_box_predictor(
is_training=False,
num_classes=5,
fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
......@@ -134,7 +96,7 @@ class MaskRCNNBoxPredictorTest(test_case.TestCase):
def test_value_error_on_predict_instance_masks_with_no_conv_hyperparms(self):
with self.assertRaises(ValueError):
self._box_predictor_builder(
box_predictor_builder.build_mask_rcnn_box_predictor(
is_training=False,
num_classes=5,
fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
......@@ -145,7 +107,7 @@ class MaskRCNNBoxPredictorTest(test_case.TestCase):
def test_get_instance_masks(self):
def graph_fn(image_features):
mask_box_predictor = self._box_predictor_builder(
mask_box_predictor = box_predictor_builder.build_mask_rcnn_box_predictor(
is_training=False,
num_classes=5,
fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
......@@ -167,7 +129,7 @@ class MaskRCNNBoxPredictorTest(test_case.TestCase):
def test_do_not_return_instance_masks_without_request(self):
image_features = tf.random_uniform([2, 7, 7, 3], dtype=tf.float32)
mask_box_predictor = self._box_predictor_builder(
mask_box_predictor = box_predictor_builder.build_mask_rcnn_box_predictor(
is_training=False,
num_classes=5,
fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mask R-CNN Class Head."""
import tensorflow as tf
from object_detection.predictors.mask_rcnn_heads import mask_rcnn_head
slim = tf.contrib.slim
class ClassHead(mask_rcnn_head.MaskRCNNHead):
"""Mask RCNN class prediction head."""
def __init__(self, is_training, num_classes, fc_hyperparams_fn,
use_dropout, dropout_keep_prob):
"""Constructor.
Args:
is_training: Indicates whether the BoxPredictor is in training mode.
num_classes: number of classes. Note that num_classes *does not*
include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}).
fc_hyperparams_fn: A function to generate tf-slim arg_scope with
hyperparameters for fully connected ops.
use_dropout: Option to use dropout or not. Note that a single dropout
op is applied here prior to both box and class predictions, which stands
in contrast to the ConvolutionalBoxPredictor below.
dropout_keep_prob: Keep probability for dropout.
This is only used if use_dropout is True.
"""
super(ClassHead, self).__init__()
self._is_training = is_training
self._num_classes = num_classes
self._fc_hyperparams_fn = fc_hyperparams_fn
self._use_dropout = use_dropout
self._dropout_keep_prob = dropout_keep_prob
def _predict(self, roi_pooled_features):
"""Predicts boxes and class scores.
Args:
roi_pooled_features: A float tensor of shape [batch_size, height, width,
channels] containing features for a batch of images.
Returns:
class_predictions_with_background: A float tensor of shape
[batch_size, 1, num_classes + 1] representing the class predictions for
the proposals.
"""
spatial_averaged_roi_pooled_features = tf.reduce_mean(
roi_pooled_features, [1, 2], keep_dims=True, name='AvgPool')
flattened_roi_pooled_features = slim.flatten(
spatial_averaged_roi_pooled_features)
if self._use_dropout:
flattened_roi_pooled_features = slim.dropout(
flattened_roi_pooled_features,
keep_prob=self._dropout_keep_prob,
is_training=self._is_training)
with slim.arg_scope(self._fc_hyperparams_fn()):
class_predictions_with_background = slim.fully_connected(
flattened_roi_pooled_features,
self._num_classes + 1,
activation_fn=None,
scope='ClassPredictor')
class_predictions_with_background = tf.reshape(
class_predictions_with_background, [-1, 1, self._num_classes + 1])
return class_predictions_with_background
......@@ -79,10 +79,9 @@ message InputReader {
// Number of groundtruth keypoints per object.
optional uint32 num_keypoints = 16 [default = 0];
// Maximum number of boxes to pad to during training.
// Set this to at least the maximum amount of boxes in the input data.
// Otherwise, it may cause "Data loss: Attempted to pad to a smaller size
// than the input element" errors.
// Maximum number of boxes to pad to during training / evaluation.
// Set this to at least the maximum amount of boxes in the input data,
// otherwise some groundtruth boxes may be clipped.
optional int32 max_number_of_boxes = 21 [default=100];
// Whether to load groundtruth instance masks.
......
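The updated comment reflects that overflow boxes are now clipped instead of causing a padding failure. Since the repo's tests parse configs with text_format, the field can be exercised the same way; an illustrative snippet (the value 200 is an example only):

from google.protobuf import text_format
from object_detection.protos import input_reader_pb2

input_reader = input_reader_pb2.InputReader()
# Size this to at least the densest image's groundtruth box count.
text_format.Merge('max_number_of_boxes: 200', input_reader)
assert input_reader.max_number_of_boxes == 200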
......@@ -12,6 +12,7 @@ import "object_detection/protos/post_processing.proto";
import "object_detection/protos/region_similarity_calculator.proto";
// Configuration for Single Shot Detection (SSD) models.
// Next id: 21
message Ssd {
// Number of classes to predict.
......@@ -80,6 +81,22 @@ message Ssd {
// a control dependency on tf.GraphKeys.UPDATE_OPS for train/loss op in order
// to update the batch norm moving average parameters.
optional bool inplace_batchnorm_update = 15 [default = false];
// Whether to weight the regression loss by the score of the ground truth box
// the anchor matches to.
optional bool weight_regression_loss_by_score = 17 [default=false];
// Whether to compute expected loss with respect to balanced positive/negative
// sampling scheme. If false, use explicit sampling.
optional bool use_expected_classification_loss_under_sampling = 18 [default=false];
// Minimum number of effective negative samples.
// Only applies if use_expected_classification_loss_under_sampling is true.
optional float minimum_negative_sampling = 19 [default=0];
// Desired number of effective negative samples per positive sample.
// Only applies if use_expected_classification_loss_under_sampling is true.
optional float desired_negative_sampling_ratio = 20 [default=3];
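As a hedged sketch, these fields plausibly map onto the arguments of expected_classification_loss_under_sampling further down in this diff; the builder wiring is not shown here, and the two input tensors are assumed to exist in scope.
from object_detection.utils import ops
# batch_cls_targets: [batch_size, num_anchors, num_classes + 1] (assumed).
# cls_losses: [batch_size, num_anchors] anchorwise losses (assumed).
loss = ops.expected_classification_loss_under_sampling(
    batch_cls_targets=batch_cls_targets,
    cls_losses=cls_losses,
    desired_negative_sampling_ratio=3.0,  # field 20 above
    minimum_negative_sampling=1.0)        # field 19 above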
}
@@ -147,3 +164,4 @@ message FeaturePyramidNetworks {
// maximum level in feature pyramid
optional int32 max_level = 2 [default = 7];
}
@@ -6,7 +6,7 @@ import "object_detection/protos/optimizer.proto";
import "object_detection/protos/preprocessor.proto";
// Message for configuring DetectionModel training jobs (train.py).
// Next id: 25
// Next id: 26
message TrainConfig {
// Effective batch size to use for training.
// For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
@@ -61,7 +61,13 @@ message TrainConfig {
// amount.
optional float bias_grad_multiplier = 11 [default=0];
// Variables that should not be updated during training.
// Variables that should be updated during training. Note that variables which
// also match the patterns in freeze_variables will be excluded.
repeated string update_trainable_variables = 25;
// Variables that should not be updated during training. If
// update_trainable_variables is not empty, the freeze_variables patterns
// are applied only to the variables that update_trainable_variables selects.
repeated string freeze_variables = 12;
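A small sketch of the selection order these two fields describe (an assumption based on the comments above, not the actual training-code implementation):
import re
def filter_trainable_variables(variables, update_patterns, freeze_patterns):
  # First restrict to update_trainable_variables matches, if any are given.
  if update_patterns:
    variables = [v for v in variables
                 if any(re.search(p, v.op.name) for p in update_patterns)]
  # Then remove anything matching a freeze_variables pattern.
  return [v for v in variables
          if not any(re.search(p, v.op.name) for p in freeze_patterns)]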
// Number of replicas to aggregate before making parameter updates.
......
@@ -91,6 +91,23 @@ class DetectionEvaluator(object):
"""
pass
def get_estimator_eval_metric_ops(self, eval_dict):
"""Returns dict of metrics to use with `tf.estimator.EstimatorSpec`.
Note that this needs to be implemented only when performing evaluation with
a `tf.estimator.Estimator`.
Args:
eval_dict: A dictionary that holds tensors for evaluating an object
detection model, returned from
eval_util.result_dict_for_single_example().
Returns:
A dictionary mapping metric names to (value_op, update_op) tuples that
can be used as eval metric ops in `tf.estimator.EstimatorSpec`.
"""
pass
@abstractmethod
def evaluate(self):
"""Evaluates detections and returns a dictionary of metrics."""
......
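A hedged sketch of a concrete evaluator satisfying this contract; the metric name, the eval_dict key, and the stubbed abstract methods are illustrative assumptions.
import tensorflow as tf
class ExampleEvaluator(DetectionEvaluator):
  """Toy evaluator; everything except the metric-ops hook is stubbed."""

  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    pass

  def add_single_detected_image_info(self, image_id, detections_dict):
    pass

  def evaluate(self):
    return {}

  def clear(self):
    pass

  def get_estimator_eval_metric_ops(self, eval_dict):
    # Each entry is a (value_op, update_op) pair, as consumed by
    # tf.estimator.EstimatorSpec(eval_metric_ops=...).
    return {'Example/mean_detection_score':
            tf.metrics.mean(eval_dict['detection_scores'])}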
@@ -1008,15 +1008,15 @@ def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses,
negative_to_positive_ratio,
desired_negative_sampling_ratio,
minimum_negative_sampling):
"""Computes classification loss by background/foreground weighting.
The weighting is such that the effective background/foreground weight ratio
is the negative_to_positive_ratio. if p_i is the foreground probability of
anchor a_i, L(a_i) is the anchors loss, N is the number of anchors, and M is
the sum of foreground probabilities across anchors, then the total loss L is
calculated as:
is the desired_negative_sampling_ratio. If p_i is the foreground probability
of anchor a_i, L(a_i) is the anchor's loss, N is the number of anchors, and M
is the sum of foreground probabilities across anchors, then the total loss L
is calculated as:
beta = K*M/(N-M)
L = sum_{i=1}^N [p_i + beta * (1 - p_i)] * (L(a_i))
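A tiny numeric check of the formula (illustrative numbers only):
import numpy as np
p = np.array([1., 0., 0., 0., 0.])  # foreground probabilities, N = 5
K = 3.0                             # desired_negative_sampling_ratio
M = p.sum()                         # M = 1
beta = K * M / (p.size - M)         # beta = 3 / 4 = 0.75
weights = p + beta * (1 - p)        # [1.0, 0.75, 0.75, 0.75, 0.75]
# Total background weight is 4 * 0.75 = 3.0 = K * M, so the effective
# background/foreground ratio is exactly K. When M == 0, the code below
# substitutes minimum_negative_sampling / num_anchors for beta instead.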
@@ -1027,14 +1027,14 @@ def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses,
the class distribution for the target assigned to a given anchor.
cls_losses: Float tensor of shape [batch_size, num_anchors]
representing anchorwise classification losses.
negative_to_positive_ratio: The desired background/foreground weight ratio.
desired_negative_sampling_ratio: The desired background/foreground weight
ratio.
minimum_negative_sampling: Minimum number of effective negative samples.
Used only when there are no positive examples.
Returns:
The classification loss.
"""
num_anchors = tf.cast(tf.shape(batch_cls_targets)[1], tf.float32)
# find the p_i
@@ -1042,7 +1042,7 @@ def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses,
foreground_probabilities_from_targets(batch_cls_targets))
foreground_sum = tf.reduce_sum(foreground_probabilities, axis=-1)
k = negative_to_positive_ratio
k = desired_negative_sampling_ratio
# compute beta
denominators = (num_anchors - foreground_sum)
@@ -1053,7 +1053,8 @@ def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses,
# where the foreground sum is zero, use a minimum negative weight.
min_negative_weight = 1.0 * minimum_negative_sampling / num_anchors
beta = tf.where(
tf.equal(beta, 0), min_negative_weight * tf.ones_like(beta), beta)
tf.equal(foreground_sum, 0), min_negative_weight * tf.ones_like(beta),
beta)
beta = tf.reshape(beta, [-1, 1])
cls_loss_weights = foreground_probabilities + (
......