Commit 324d6dc3 authored by Zhichao Lu, committed by pkulzc

Merged commit includes the following changes:

196161788  by Zhichao Lu:

    Add eval_on_train_steps parameter.

    This is needed because the number of samples in the train dataset usually differs from the number of samples in the eval dataset.

--
196151742  by Zhichao Lu:

    Add an optional random sampling process for SSD meta arch and update mean stddev coder to use default std dev when corresponding tensor is not added to boxlist field.

--
196148940  by Zhichao Lu:

    Release ssdlite mobilenet v2 coco trained model.

--
196058528  by Zhichao Lu:

    Apply FPN feature map generation before we add additional layers on top of resnet feature extractor.

--
195818367  by Zhichao Lu:

    Add support for exporting detection keypoints.

--
195745420  by Zhichao Lu:

    Introduce include_metrics_per_category option to Object Detection eval_config.

--
195734733  by Zhichao Lu:

    Rename SSDLite config to be more explicit.

--
195717383  by Zhichao Lu:

    Add quantized training to object_detection.

--
195683542  by Zhichao Lu:

    Fix documentation for the interaction of fine_tune_checkpoint_type and load_all_detection_checkpoint_vars.

--
195668233  by Zhichao Lu:

    Using batch size from params dictionary if present.

--
195570173  by Zhichao Lu:

    A few fixes to get new estimator API eval to match legacy detection eval binary by (1) plumbing `is_crowd` annotations through to COCO evaluator, (2) setting the `sloppy` flag in tf.contrib.data.parallel_interleave based on whether shuffling is enabled, and (3) saving the original image instead of the resized original image, which allows for small/medium/large mAP metrics to be properly computed.

--
195316756  by Zhichao Lu:

    Internal change

--

PiperOrigin-RevId: 196161788
Parent 63054210
......@@ -25,6 +25,14 @@ from object_detection.core import box_list
class MeanStddevBoxCoder(box_coder.BoxCoder):
"""Mean stddev box coder."""
def __init__(self, stddev=0.01):
"""Constructor for MeanStddevBoxCoder.
Args:
stddev: The standard deviation used to encode and decode boxes.
"""
self._stddev = stddev
@property
def code_size(self):
return 4
......@@ -34,37 +42,38 @@ class MeanStddevBoxCoder(box_coder.BoxCoder):
Args:
boxes: BoxList holding N boxes to be encoded.
anchors: BoxList of N anchors. We assume that anchors has an associated
stddev field.
anchors: BoxList of N anchors.
Returns:
a tensor representing N anchor-encoded boxes
Raises:
ValueError: if the anchors BoxList does not have a stddev field
ValueError: if the anchors BoxList still has the deprecated stddev field.
"""
if not anchors.has_field('stddev'):
raise ValueError('anchors must have a stddev field')
box_corners = boxes.get()
if anchors.has_field('stddev'):
raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
"should not be specified in the box list.")
means = anchors.get()
stddev = anchors.get_field('stddev')
return (box_corners - means) / stddev
return (box_corners - means) / self._stddev
def _decode(self, rel_codes, anchors):
"""Decode.
Args:
rel_codes: a tensor representing N anchor-encoded boxes.
anchors: BoxList of anchors. We assume that anchors has an associated
stddev field.
anchors: BoxList of anchors.
Returns:
boxes: BoxList holding N bounding boxes
Raises:
ValueError: if the anchors BoxList does not have a stddev field
ValueError: if the anchors BoxList still has the deprecated stddev field and
expects the decode method to use stddev values from that field.
"""
if not anchors.has_field('stddev'):
raise ValueError('anchors must have a stddev field')
means = anchors.get()
stddevs = anchors.get_field('stddev')
box_corners = rel_codes * stddevs + means
if anchors.has_field('stddev'):
raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
"should not be specified in the box list.")
box_corners = rel_codes * self._stddev + means
return box_list.BoxList(box_corners)
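For a concrete sense of the coder's arithmetic, here is a minimal NumPy sketch (not part of this commit) using the same box/anchor values as the tests below, with stddev = 0.1:

import numpy as np

# encode: rel = (box - anchor_mean) / stddev; decode inverts it.
stddev = 0.1
box = np.array([0.0, 0.0, 0.5, 0.5])
anchor_mean = np.array([0.5, 0.5, 1.0, 0.8])

rel_code = (box - anchor_mean) / stddev
print(rel_code)                         # [-5. -5. -5. -3.]
print(rel_code * stddev + anchor_mean)  # [0.  0.  0.5  0.5]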
......@@ -28,11 +28,9 @@ class MeanStddevBoxCoderTest(tf.test.TestCase):
boxes = box_list.BoxList(tf.constant(box_corners))
expected_rel_codes = [[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]]
prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
prior_stddevs = tf.constant(2 * [4 * [.1]])
priors = box_list.BoxList(prior_means)
priors.add_field('stddev', prior_stddevs)
coder = mean_stddev_box_coder.MeanStddevBoxCoder()
coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
rel_codes = coder.encode(boxes, priors)
with self.test_session() as sess:
rel_codes_out = sess.run(rel_codes)
......@@ -42,11 +40,9 @@ class MeanStddevBoxCoderTest(tf.test.TestCase):
rel_codes = tf.constant([[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]])
expected_box_corners = [[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5]]
prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
prior_stddevs = tf.constant(2 * [4 * [.1]])
priors = box_list.BoxList(prior_means)
priors.add_field('stddev', prior_stddevs)
coder = mean_stddev_box_coder.MeanStddevBoxCoder()
coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
decoded_boxes = coder.decode(rel_codes, priors)
decoded_box_corners = decoded_boxes.get()
with self.test_session() as sess:
......
......@@ -55,7 +55,8 @@ def build(box_coder_config):
])
if (box_coder_config.WhichOneof('box_coder_oneof') ==
'mean_stddev_box_coder'):
return mean_stddev_box_coder.MeanStddevBoxCoder()
return mean_stddev_box_coder.MeanStddevBoxCoder(
stddev=box_coder_config.mean_stddev_box_coder.stddev)
if box_coder_config.WhichOneof('box_coder_oneof') == 'square_box_coder':
return square_box_coder.SquareBoxCoder(scale_factors=[
box_coder_config.square_box_coder.y_scale,
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for quantized training and evaluation."""
import tensorflow as tf
def build(graph_rewriter_config, is_training):
"""Returns a function that modifies default graph based on options.
Args:
graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
is_training: whether in training or eval mode.
"""
def graph_rewrite_fn():
"""Function to quantize weights and activation of the default graph."""
if (graph_rewriter_config.quantization.weight_bits != 8 or
graph_rewriter_config.quantization.activation_bits != 8):
raise ValueError('Only 8bit quantization is supported')
# Quantize the graph by inserting quantize ops for weights and activations
if is_training:
tf.contrib.quantize.create_training_graph(
input_graph=tf.get_default_graph(),
quant_delay=graph_rewriter_config.quantization.delay)
else:
tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph())
tf.contrib.layers.summarize_collection('quant_vars')
return graph_rewrite_fn
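A minimal usage sketch (not part of this commit): build a GraphRewriter proto, construct the rewrite function, and call it after the model graph has been built. The delay value here is an arbitrary example.

from object_detection.builders import graph_rewriter_builder
from object_detection.protos import graph_rewriter_pb2

graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 2000  # example; default is 500000

# ... build the detection model's training graph first ...

graph_rewrite_fn = graph_rewriter_builder.build(
    graph_rewriter_proto, is_training=True)
graph_rewrite_fn()  # inserts fake-quant ops into the default graph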
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_rewriter_builder."""
import mock
import tensorflow as tf
from object_detection.builders import graph_rewriter_builder
from object_detection.protos import graph_rewriter_pb2
class QuantizationBuilderTest(tf.test.TestCase):
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
with mock.patch.object(
tf.contrib.quantize, 'create_training_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers,
'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10
graph_rewriter_proto.quantization.weight_bits = 8
graph_rewriter_proto.quantization.activation_bits = 8
graph_rewrite_fn = graph_rewriter_builder.build(
graph_rewriter_proto, is_training=True)
graph_rewrite_fn()
_, kwargs = mock_quant_fn.call_args
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
self.assertEqual(kwargs['quant_delay'], 10)
mock_summarize_col.assert_called_with('quant_vars')
def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
with mock.patch.object(tf.contrib.quantize,
'create_eval_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers,
'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10
graph_rewrite_fn = graph_rewriter_builder.build(
graph_rewriter_proto, is_training=False)
graph_rewrite_fn()
_, kwargs = mock_quant_fn.call_args
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
mock_summarize_col.assert_called_with('quant_vars')
if __name__ == '__main__':
tf.test.main()
......@@ -15,6 +15,7 @@
"""A function to build localization and classification losses from config."""
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import losses
from object_detection.protos import losses_pb2
......@@ -34,9 +35,12 @@ def build(loss_config):
classification_weight: Classification loss weight.
localization_weight: Localization loss weight.
hard_example_miner: Hard example miner object.
random_example_sampler: BalancedPositiveNegativeSampler object.
Raises:
ValueError: If hard_example_miner is used with sigmoid_focal_loss.
ValueError: If random_example_sampler is getting non-positive value as
desired positive example fraction.
"""
classification_loss = _build_classification_loss(
loss_config.classification_loss)
......@@ -54,9 +58,16 @@ def build(loss_config):
loss_config.hard_example_miner,
classification_weight,
localization_weight)
return (classification_loss, localization_loss,
classification_weight,
localization_weight, hard_example_miner)
random_example_sampler = None
if loss_config.HasField('random_example_sampler'):
if loss_config.random_example_sampler.positive_sample_fraction <= 0:
raise ValueError('RandomExampleSampler should not use non-positive '
'value as positive sample fraction.')
random_example_sampler = sampler.BalancedPositiveNegativeSampler(
positive_fraction=loss_config.random_example_sampler.
positive_sample_fraction)
return (classification_loss, localization_loss, classification_weight,
localization_weight, hard_example_miner, random_example_sampler)
def build_hard_example_miner(config,
......
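The builder now returns a six-tuple. A minimal sketch (not part of this commit) of enabling random example sampling from a text-format Loss proto and unpacking the result; the loss choices are arbitrary examples, and the sampler field comes from the RandomExampleSampler message added to losses.proto later in this commit:

from google.protobuf import text_format
from object_detection.builders import losses_builder
from object_detection.protos import losses_pb2

losses_text_proto = """
  localization_loss { weighted_l2 {} }
  classification_loss { weighted_softmax {} }
  random_example_sampler { positive_sample_fraction: 0.05 }
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)

(classification_loss, localization_loss, classification_weight,
 localization_weight, hard_example_miner,
 random_example_sampler) = losses_builder.build(losses_proto)
# random_example_sampler is a BalancedPositiveNegativeSampler here.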
......@@ -38,7 +38,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, localization_loss, _, _, _ = losses_builder.build(losses_proto)
_, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(localization_loss,
losses.WeightedL2LocalizationLoss))
......@@ -55,7 +55,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, localization_loss, _, _, _ = losses_builder.build(losses_proto)
_, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(localization_loss,
losses.WeightedSmoothL1LocalizationLoss))
self.assertAlmostEqual(localization_loss._delta, 1.0)
......@@ -74,7 +74,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, localization_loss, _, _, _ = losses_builder.build(losses_proto)
_, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(localization_loss,
losses.WeightedSmoothL1LocalizationLoss))
self.assertAlmostEqual(localization_loss._delta, 0.1)
......@@ -92,7 +92,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, localization_loss, _, _, _ = losses_builder.build(losses_proto)
_, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(localization_loss,
losses.WeightedIOULocalizationLoss))
......@@ -109,7 +109,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, localization_loss, _, _, _ = losses_builder.build(losses_proto)
_, localization_loss, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(localization_loss,
losses.WeightedSmoothL1LocalizationLoss))
predictions = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
......@@ -146,7 +146,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(classification_loss,
losses.WeightedSigmoidClassificationLoss))
......@@ -163,7 +163,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(classification_loss,
losses.SigmoidFocalClassificationLoss))
self.assertAlmostEqual(classification_loss._alpha, None)
......@@ -184,7 +184,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(classification_loss,
losses.SigmoidFocalClassificationLoss))
self.assertAlmostEqual(classification_loss._alpha, 0.25)
......@@ -203,7 +203,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(classification_loss,
losses.WeightedSoftmaxClassificationLoss))
......@@ -220,7 +220,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(
isinstance(classification_loss,
losses.WeightedSoftmaxClassificationAgainstLogitsLoss))
......@@ -239,7 +239,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(classification_loss,
losses.WeightedSoftmaxClassificationLoss))
......@@ -257,7 +257,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(classification_loss,
losses.BootstrappedSigmoidClassificationLoss))
......@@ -275,7 +275,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
classification_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(classification_loss,
losses.WeightedSigmoidClassificationLoss))
predictions = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]])
......@@ -312,7 +312,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
_, _, _, _, hard_example_miner, _ = losses_builder.build(losses_proto)
self.assertEqual(hard_example_miner, None)
def test_build_hard_example_miner_for_classification_loss(self):
......@@ -331,7 +331,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
_, _, _, _, hard_example_miner, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
self.assertEqual(hard_example_miner._loss_type, 'cls')
......@@ -351,7 +351,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
_, _, _, _, hard_example_miner, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
self.assertEqual(hard_example_miner._loss_type, 'loc')
......@@ -375,7 +375,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, _, _, _, hard_example_miner = losses_builder.build(losses_proto)
_, _, _, _, hard_example_miner, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
self.assertEqual(hard_example_miner._num_hard_examples, 32)
self.assertAlmostEqual(hard_example_miner._iou_threshold, 0.5)
......@@ -404,7 +404,7 @@ class LossBuilderTest(tf.test.TestCase):
text_format.Merge(losses_text_proto, losses_proto)
(classification_loss, localization_loss,
classification_weight, localization_weight,
hard_example_miner) = losses_builder.build(losses_proto)
hard_example_miner, _) = losses_builder.build(losses_proto)
self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
self.assertTrue(isinstance(classification_loss,
losses.WeightedSoftmaxClassificationLoss))
......
......@@ -180,8 +180,8 @@ def _build_ssd_model(ssd_config, is_training, add_summaries,
non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
ssd_config.post_processing)
(classification_loss, localization_loss, classification_weight,
localization_weight,
hard_example_miner) = losses_builder.build(ssd_config.loss)
localization_weight, hard_example_miner,
random_example_sampler) = losses_builder.build(ssd_config.loss)
normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize
......@@ -208,7 +208,8 @@ def _build_ssd_model(ssd_config, is_training, add_summaries,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=ssd_config.freeze_batchnorm,
inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
add_background_class=add_background_class)
add_background_class=add_background_class,
random_example_sampler=random_example_sampler)
def _build_faster_rcnn_feature_extractor(
......
......@@ -39,6 +39,7 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
Args:
positive_fraction: desired fraction of positive examples (scalar in [0,1])
in the batch.
Raises:
ValueError: if positive_fraction < 0, or positive_fraction > 1
......@@ -53,7 +54,9 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
Args:
indicator: boolean tensor of shape [N] whose True entries can be sampled.
batch_size: desired batch size.
batch_size: desired batch size. If None, keeps all positive samples and
randomly selects negative samples so that the positive sample fraction
matches self._positive_fraction.
labels: boolean tensor of shape [N] denoting positive(=True) and negative
(=False) examples.
......@@ -83,9 +86,19 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
negative_idx = tf.logical_and(negative_idx, indicator)
# Sample positive and negative samples separately
max_num_pos = int(self._positive_fraction * batch_size)
if batch_size is None:
max_num_pos = tf.reduce_sum(tf.to_int32(positive_idx))
else:
max_num_pos = int(self._positive_fraction * batch_size)
sampled_pos_idx = self.subsample_indicator(positive_idx, max_num_pos)
max_num_neg = batch_size - tf.reduce_sum(tf.cast(sampled_pos_idx, tf.int32))
num_sampled_pos = tf.reduce_sum(tf.cast(sampled_pos_idx, tf.int32))
if batch_size is None:
negative_positive_ratio = (
1 - self._positive_fraction) / self._positive_fraction
max_num_neg = tf.to_int32(
negative_positive_ratio * tf.to_float(num_sampled_pos))
else:
max_num_neg = batch_size - num_sampled_pos
sampled_neg_idx = self.subsample_indicator(negative_idx, max_num_neg)
sampled_idx = tf.logical_or(sampled_pos_idx, sampled_neg_idx)
......
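With batch_size=None, the total number of sampled examples is therefore driven by the number of positives. A small arithmetic sketch using the same numbers as the test below (positive_fraction = 0.01, 5 sampled positives):

# All sampled positives are kept; negatives are drawn to hit the target ratio.
positive_fraction = 0.01
num_sampled_pos = 5
max_num_neg = int((1 - positive_fraction) / positive_fraction * num_sampled_pos)
print(max_num_neg)                    # 495
print(num_sampled_pos + max_num_neg)  # 500 examples total, 1% positive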
......@@ -19,9 +19,10 @@ import numpy as np
import tensorflow as tf
from object_detection.core import balanced_positive_negative_sampler
from object_detection.utils import test_case
class BalancedPositiveNegativeSamplerTest(tf.test.TestCase):
class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
def test_subsample_all_examples(self):
numpy_labels = np.random.permutation(300)
......@@ -62,6 +63,28 @@ class BalancedPositiveNegativeSamplerTest(tf.test.TestCase):
self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
numpy_indicator))
def test_subsample_selection_no_batch_size(self):
# Test random sampling when only some examples can be sampled:
# 1000 samples, 6 positives (5 can be sampled).
numpy_labels = np.arange(1000)
numpy_indicator = numpy_labels < 999
indicator = tf.constant(numpy_indicator)
numpy_labels = (numpy_labels - 994) >= 0
labels = tf.constant(numpy_labels)
sampler = (balanced_positive_negative_sampler.
BalancedPositiveNegativeSampler(0.01))
is_sampled = sampler.subsample(indicator, None, labels)
with self.test_session() as sess:
is_sampled = sess.run(is_sampled)
self.assertTrue(sum(is_sampled) == 500)
self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 5)
self.assertTrue(sum(np.logical_and(
np.logical_not(numpy_labels), is_sampled)) == 495)
self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
numpy_indicator))
def test_raises_error_with_incorrect_label_shape(self):
labels = tf.constant([[True, False, False]])
indicator = tf.constant([True, False, True])
......
......@@ -237,7 +237,8 @@ class DetectionModel(object):
groundtruth_classes_list,
groundtruth_masks_list=None,
groundtruth_keypoints_list=None,
groundtruth_weights_list=None):
groundtruth_weights_list=None,
groundtruth_is_crowd_list=None):
"""Provide groundtruth tensors.
Args:
......@@ -260,6 +261,8 @@ class DetectionModel(object):
missing keypoints should be encoded as NaN.
groundtruth_weights_list: A list of 1-D tf.float32 tensors of shape
[num_boxes] containing weights for groundtruth boxes.
groundtruth_is_crowd_list: A list of 1-D tf.bool tensors of shape
[num_boxes] containing is_crowd annotations.
"""
self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
self._groundtruth_lists[
......@@ -273,6 +276,9 @@ class DetectionModel(object):
if groundtruth_keypoints_list:
self._groundtruth_lists[
fields.BoxListFields.keypoints] = groundtruth_keypoints_list
if groundtruth_is_crowd_list:
self._groundtruth_lists[
fields.BoxListFields.is_crowd] = groundtruth_is_crowd_list
@abstractmethod
def restore_map(self, fine_tune_checkpoint_type='detection'):
......
......@@ -132,6 +132,7 @@ class BoxListFields(object):
boundaries: boundaries per bounding box.
keypoints: keypoints per bounding box.
keypoint_heatmaps: keypoint heatmaps per bounding box.
is_crowd: is_crowd annotation per bounding box.
"""
boxes = 'boxes'
classes = 'classes'
......@@ -142,6 +143,7 @@ class BoxListFields(object):
boundaries = 'boundaries'
keypoints = 'keypoints'
keypoint_heatmaps = 'keypoint_heatmaps'
is_crowd = 'is_crowd'
class TfExampleFields(object):
......
......@@ -49,6 +49,7 @@ import tensorflow as tf
from object_detection import evaluator
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import dataset_util
......@@ -127,8 +128,19 @@ def main(unused_argv):
if FLAGS.run_once:
eval_config.max_evals = 1
evaluator.evaluate(create_input_dict_fn, model_fn, eval_config, categories,
FLAGS.checkpoint_dir, FLAGS.eval_dir)
graph_rewriter_fn = None
if 'graph_rewriter_config' in configs:
graph_rewriter_fn = graph_rewriter_builder.build(
configs['graph_rewriter_config'], is_training=False)
evaluator.evaluate(
create_input_dict_fn,
model_fn,
eval_config,
categories,
FLAGS.checkpoint_dir,
FLAGS.eval_dir,
graph_hook_fn=graph_rewriter_fn)
if __name__ == '__main__':
......
......@@ -588,7 +588,8 @@ def get_eval_metric_ops_for_evaluators(evaluation_metrics,
'name': (required) string representing category name e.g., 'cat', 'dog'.
eval_dict: An evaluation dictionary, returned from
result_dict_for_single_example().
include_metrics_per_category: If True, include metrics for each category.
include_metrics_per_category: If True, additionally include per-category
metrics.
Returns:
A dictionary of metric names to tuple of value_op and update_op that can be
......@@ -615,7 +616,9 @@ def get_eval_metric_ops_for_evaluators(evaluation_metrics,
input_data_fields.groundtruth_classes],
detection_boxes=eval_dict[detection_fields.detection_boxes],
detection_scores=eval_dict[detection_fields.detection_scores],
detection_classes=eval_dict[detection_fields.detection_classes]))
detection_classes=eval_dict[detection_fields.detection_classes],
groundtruth_is_crowd=eval_dict.get(
input_data_fields.groundtruth_is_crowd)))
elif metric == 'coco_mask_metrics':
coco_mask_evaluator = coco_evaluation.CocoMaskEvaluator(
categories, include_metrics_per_category=include_metrics_per_category)
......@@ -629,7 +632,9 @@ def get_eval_metric_ops_for_evaluators(evaluation_metrics,
input_data_fields.groundtruth_instance_masks],
detection_scores=eval_dict[detection_fields.detection_scores],
detection_classes=eval_dict[detection_fields.detection_classes],
detection_masks=eval_dict[detection_fields.detection_masks]))
detection_masks=eval_dict[detection_fields.detection_masks],
groundtruth_is_crowd=eval_dict.get(
input_data_fields.groundtruth_is_crowd),))
else:
raise ValueError('The only evaluation metrics supported are '
'"coco_detection_metrics" and "coco_mask_metrics". '
......
......@@ -197,6 +197,9 @@ def _add_output_tensor_nodes(postprocessed_tensors,
containing scores for the detected boxes.
* detection_classes: float32 tensor of shape [batch_size, num_boxes]
containing class predictions for the detected boxes.
* detection_keypoints: (Optional) float32 tensor of shape
[batch_size, num_boxes, num_keypoints, 2] containing keypoints for each
detection box.
* detection_masks: (Optional) float32 tensor of shape
[batch_size, num_boxes, mask_height, mask_width] containing masks for each
detection box.
......@@ -220,6 +223,7 @@ def _add_output_tensor_nodes(postprocessed_tensors,
scores = postprocessed_tensors.get(detection_fields.detection_scores)
classes = postprocessed_tensors.get(
detection_fields.detection_classes) + label_id_offset
keypoints = postprocessed_tensors.get(detection_fields.detection_keypoints)
masks = postprocessed_tensors.get(detection_fields.detection_masks)
num_detections = postprocessed_tensors.get(detection_fields.num_detections)
outputs = {}
......@@ -231,6 +235,9 @@ def _add_output_tensor_nodes(postprocessed_tensors,
classes, name=detection_fields.detection_classes)
outputs[detection_fields.num_detections] = tf.identity(
num_detections, name=detection_fields.num_detections)
if keypoints is not None:
outputs[detection_fields.detection_keypoints] = tf.identity(
keypoints, name=detection_fields.detection_keypoints)
if masks is not None:
outputs[detection_fields.detection_masks] = tf.identity(
masks, name=detection_fields.detection_masks)
......
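Once exported, the new keypoints output can be fetched by tensor name alongside the other outputs. A minimal sketch, assuming a frozen inference graph already loaded as detection_graph with the standard image_tensor input (both assumptions, not shown in this diff):

import numpy as np
import tensorflow as tf

image_np = np.zeros((300, 300, 3), dtype=np.uint8)  # hypothetical input image

with tf.Session(graph=detection_graph) as sess:
    boxes, keypoints, num = sess.run(
        ['detection_boxes:0', 'detection_keypoints:0', 'num_detections:0'],
        feed_dict={'image_tensor:0': np.expand_dims(image_np, axis=0)})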
......@@ -71,6 +71,7 @@ Some remarks on frozen inference graphs:
| ------------ | :--------------: | :--------------: | :-------------: |
| [ssd_mobilenet_v1_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz) | 30 | 21 | Boxes |
| [ssd_mobilenet_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz) | 31 | 22 | Boxes |
| [ssdlite_mobilenet_v2_coco](http://download.tensorflow.org/models/object_detection/ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz) | 27 | 22 | Boxes |
| [ssd_inception_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz) | 42 | 24 | Boxes |
| [faster_rcnn_inception_v2_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz) | 58 | 28 | Boxes |
| [faster_rcnn_resnet50_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz) | 89 | 30 | Boxes |
......
......@@ -71,11 +71,10 @@ def transform_input_data(tensor_dict,
model_preprocess_fn: model's preprocess function to apply on image tensor.
This function must take in a 4-D float tensor and return a 4-D preprocess
float tensor and a tensor containing the true image shape.
image_resizer_fn: image resizer function to apply on original image (if
`retain_original_image` is True) and groundtruth instance masks. This
function must take a 3-D float tensor of an image and a 3-D tensor of
instance masks and return a resized version of these along with the true
shapes.
image_resizer_fn: image resizer function to apply on groundtruth instance
masks. This function must take a 3-D float tensor of an image and a 3-D
tensor of instance masks and return a resized version of these along with
the true shapes.
num_classes: number of max classes to one-hot (or k-hot) encode the class
labels.
data_augmentation_fn: (optional) data augmentation function to apply on
......@@ -90,10 +89,8 @@ def transform_input_data(tensor_dict,
after applying all the transformations.
"""
if retain_original_image:
original_image_resized, _ = image_resizer_fn(
tensor_dict[fields.InputDataFields.image])
tensor_dict[fields.InputDataFields.original_image] = tf.cast(
original_image_resized, tf.uint8)
tensor_dict[fields.InputDataFields.image], tf.uint8)
# Apply data augmentation ops.
if data_augmentation_fn is not None:
......@@ -350,7 +347,7 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
TypeError: if the `eval_config`, `eval_input_config` or `model_config`
are not of the correct type.
"""
del params
params = params or {}
if not isinstance(eval_config, eval_pb2.EvalConfig):
raise TypeError('For eval mode, the `eval_config` must be a '
'train_pb2.EvalConfig.')
......@@ -375,7 +372,7 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
dataset = INPUT_BUILDER_UTIL_MAP['dataset_build'](
eval_input_config,
transform_input_data_fn=transform_data_fn,
batch_size=1,
batch_size=params.get('batch_size', 1),
num_classes=config_util.get_number_of_classes(model_config),
spatial_image_shape=config_util.get_spatial_image_size(
image_resizer_config))
......
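The returned eval input fn now honors a runtime-supplied batch size. A minimal sketch, assuming eval_config, eval_input_config, and model_config protos are already built (TPUEstimator, for instance, passes a per-shard batch_size through params):

from object_detection import inputs

eval_input_fn = inputs.create_eval_input_fn(
    eval_config, eval_input_config, model_config)
eval_data = eval_input_fn(params={'batch_size': 8})  # eval batches of 8
eval_data = eval_input_fn(params=None)               # falls back to batch size 1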
......@@ -482,7 +482,7 @@ class DataTransformationFnTest(tf.test.TestCase):
self.assertAllEqual(transformed_inputs[
fields.InputDataFields.original_image].dtype, tf.uint8)
self.assertAllEqual(transformed_inputs[
fields.InputDataFields.original_image].shape, [8, 8, 3])
fields.InputDataFields.original_image].shape, [4, 4, 3])
self.assertAllEqual(transformed_inputs[
fields.InputDataFields.groundtruth_instance_masks].shape, [2, 8, 8])
......
......@@ -139,7 +139,8 @@ class SSDMetaArch(model.DetectionModel):
normalize_loc_loss_by_codesize=False,
freeze_batchnorm=False,
inplace_batchnorm_update=False,
add_background_class=True):
add_background_class=True,
random_example_sampler=None):
"""SSDMetaArch Constructor.
TODO(rathodv,jonathanhuang): group NMS parameters + score converter into
......@@ -198,6 +199,12 @@ class SSDMetaArch(model.DetectionModel):
one-hot encodings of groundtruth labels. Set to false if using
groundtruth labels with an explicit background class or using multiclass
scores instead of truth in the case of distillation.
random_example_sampler: a BalancedPositiveNegativeSampler object that can
perform random example sampling when computing loss. If None, the random
sampling process is skipped. Note that the random example sampler and hard
example miner can both be applied to the model. In that case, the random
sampler takes effect first and the hard example miner only processes
the randomly sampled examples.
"""
super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes)
self._is_training = is_training
......@@ -240,6 +247,8 @@ class SSDMetaArch(model.DetectionModel):
self._normalize_loss_by_num_matches = normalize_loss_by_num_matches
self._normalize_loc_loss_by_codesize = normalize_loc_loss_by_codesize
self._hard_example_miner = hard_example_miner
self._random_example_sampler = random_example_sampler
self._parallel_iterations = 16
self._image_resizer_fn = image_resizer_fn
self._non_max_suppression_fn = non_max_suppression_fn
......@@ -543,6 +552,20 @@ class SSDMetaArch(model.DetectionModel):
if self._add_summaries:
self._summarize_target_assignment(
self.groundtruth_lists(fields.BoxListFields.boxes), match_list)
if self._random_example_sampler:
batch_sampled_indicator = tf.to_float(
shape_utils.static_or_dynamic_map_fn(
self._minibatch_subsample_fn,
[batch_cls_targets, batch_cls_weights],
dtype=tf.bool,
parallel_iterations=self._parallel_iterations,
back_prop=True))
batch_reg_weights = tf.multiply(batch_sampled_indicator,
batch_reg_weights)
batch_cls_weights = tf.multiply(batch_sampled_indicator,
batch_cls_weights)
location_losses = self._localization_loss(
prediction_dict['box_encodings'],
batch_reg_targets,
......@@ -593,6 +616,32 @@ class SSDMetaArch(model.DetectionModel):
}
return loss_dict
def _minibatch_subsample_fn(self, inputs):
"""Randomly samples anchors for one image.
Args:
inputs: a list of 2 inputs. First one is a tensor of shape [num_anchors,
num_classes] indicating targets assigned to each anchor. Second one
is a tensor of shape [num_anchors] indicating the class weight of each
anchor.
Returns:
batch_sampled_indicator: bool tensor of shape [num_anchors] indicating
whether the anchor should be selected for loss computation.
"""
cls_targets, cls_weights = inputs
if self._add_background_class:
# Set background_class bits to 0 so that the positives_indicator
# computation does not consider the background class.
background_class = tf.zeros_like(tf.slice(cls_targets, [0, 0], [-1, 1]))
regular_class = tf.slice(cls_targets, [0, 1], [-1, -1])
cls_targets = tf.concat([background_class, regular_class], 1)
positives_indicator = tf.reduce_sum(cls_targets, axis=1)
return self._random_example_sampler.subsample(
tf.cast(cls_weights, tf.bool),
batch_size=None,
labels=tf.cast(positives_indicator, tf.bool))
def _summarize_anchor_classification_loss(self, class_ids, cls_losses):
positive_indices = tf.where(tf.greater(class_ids, 0))
positive_anchor_cls_loss = tf.squeeze(
......@@ -790,8 +839,8 @@ class SSDMetaArch(model.DetectionModel):
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'.
load_all_detection_checkpoint_vars: whether to load all variables (when
`from_detection_checkpoint` is True). If False, only variables within
the appropriate scopes are included. Default False.
`fine_tune_checkpoint_type='detection'`). If False, only variables
within the appropriate scopes are included. Default False.
Returns:
A dict mapping variable names (to load from a checkpoint) to variables in
......
......@@ -19,6 +19,7 @@ import numpy as np
import tensorflow as tf
from object_detection.core import anchor_generator
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import box_list
from object_detection.core import losses
from object_detection.core import post_processing
......@@ -83,7 +84,8 @@ class SsdMetaArchTest(test_case.TestCase):
def _create_model(self,
apply_hard_mining=True,
normalize_loc_loss_by_codesize=False,
add_background_class=True):
add_background_class=True,
random_example_sampling=False):
is_training = False
num_classes = 1
mock_anchor_generator = MockAnchorGenerator2x2()
......@@ -117,6 +119,11 @@ class SsdMetaArchTest(test_case.TestCase):
num_hard_examples=None,
iou_threshold=1.0)
random_example_sampler = None
if random_example_sampling:
random_example_sampler = sampler.BalancedPositiveNegativeSampler(
positive_fraction=0.5)
code_size = 4
model = ssd_meta_arch.SSDMetaArch(
is_training,
......@@ -141,7 +148,8 @@ class SsdMetaArchTest(test_case.TestCase):
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=False,
inplace_batchnorm_update=False,
add_background_class=add_background_class)
add_background_class=add_background_class,
random_example_sampler=random_example_sampler)
return model, num_classes, mock_anchor_generator.num_anchors(), code_size
def test_preprocess_preserves_shapes_with_dynamic_input_image(self):
......@@ -493,6 +501,47 @@ class SsdMetaArchTest(test_case.TestCase):
self.assertIsInstance(var_map, dict)
self.assertIn('another_variable', var_map)
def test_loss_results_are_correct_with_random_example_sampling(self):
with tf.Graph().as_default():
_, num_classes, num_anchors, _ = self._create_model(
random_example_sampling=True)
def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2,
groundtruth_classes1, groundtruth_classes2):
groundtruth_boxes_list = [groundtruth_boxes1, groundtruth_boxes2]
groundtruth_classes_list = [groundtruth_classes1, groundtruth_classes2]
model, _, _, _ = self._create_model(random_example_sampling=True)
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
prediction_dict = model.predict(
preprocessed_tensor, true_image_shapes=None)
loss_dict = model.loss(prediction_dict, true_image_shapes=None)
return (_get_value_for_matching_key(loss_dict, 'Loss/localization_loss'),
_get_value_for_matching_key(loss_dict,
'Loss/classification_loss'))
batch_size = 2
preprocessed_input = np.random.rand(batch_size, 2, 2, 3).astype(np.float32)
groundtruth_boxes1 = np.array([[0, 0, .5, .5]], dtype=np.float32)
groundtruth_boxes2 = np.array([[0, 0, .5, .5]], dtype=np.float32)
groundtruth_classes1 = np.array([[1]], dtype=np.float32)
groundtruth_classes2 = np.array([[1]], dtype=np.float32)
expected_localization_loss = 0.0
# Among 4 anchors (1 positive, 3 negative) in this test, only 2 anchors are
# selected (1 positive, 1 negative), since the random sampler adjusts the
# number of negative examples so that the positive example fraction in the
# batch is 0.5.
expected_classification_loss = (
batch_size * 2 * (num_classes + 1) * np.log(2.0))
(localization_loss, classification_loss) = self.execute_cpu(
graph_fn, [
preprocessed_input, groundtruth_boxes1, groundtruth_boxes2,
groundtruth_classes1, groundtruth_classes2
])
self.assertAllClose(localization_loss, expected_localization_loss)
self.assertAllClose(classification_loss, expected_classification_loss)
if __name__ == '__main__':
tf.test.main()
......@@ -202,8 +202,10 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
return box_metrics
def get_estimator_eval_metric_ops(self, image_id, groundtruth_boxes,
groundtruth_classes, detection_boxes,
groundtruth_classes,
detection_boxes,
detection_scores, detection_classes,
groundtruth_is_crowd=None,
num_gt_boxes_per_image=None,
num_det_boxes_per_image=None):
"""Returns a dictionary of eval metric ops to use with `tf.EstimatorSpec`.
......@@ -230,6 +232,9 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
detection scores for the boxes.
detection_classes: int32 tensor of shape [batch, num_boxes] containing
1-indexed detection classes for the boxes.
groundtruth_is_crowd: bool tensor of shape [batch, num_boxes] containing
is_crowd annotations. This field is optional, and if not passed, then
all boxes are treated as *not* is_crowd.
num_gt_boxes_per_image: int32 tensor of shape [batch] containing the
number of groundtruth boxes per image. If None, will assume no padding
in groundtruth tensors.
......@@ -247,6 +252,7 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
image_id_batched,
groundtruth_boxes_batched,
groundtruth_classes_batched,
groundtruth_is_crowd_batched,
num_gt_boxes_per_image,
detection_boxes_batched,
detection_scores_batched,
......@@ -254,27 +260,32 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
num_det_boxes_per_image):
"""Update operation for adding batch of images to Coco evaluator."""
for (image_id, gt_box, gt_class, num_gt_box, det_box, det_score,
det_class, num_det_box) in zip(
for (image_id, gt_box, gt_class, gt_is_crowd, num_gt_box, det_box,
det_score, det_class, num_det_box) in zip(
image_id_batched, groundtruth_boxes_batched,
groundtruth_classes_batched, num_gt_boxes_per_image,
groundtruth_classes_batched, groundtruth_is_crowd_batched,
num_gt_boxes_per_image,
detection_boxes_batched, detection_scores_batched,
detection_classes_batched, num_det_boxes_per_image):
self.add_single_ground_truth_image_info(
image_id,
{'groundtruth_boxes': gt_box[:num_gt_box],
'groundtruth_classes': gt_class[:num_gt_box]})
'groundtruth_classes': gt_class[:num_gt_box],
'groundtruth_is_crowd': gt_is_crowd[:num_gt_box]})
self.add_single_detected_image_info(
image_id,
{'detection_boxes': det_box[:num_det_box],
'detection_scores': det_score[:num_det_box],
'detection_classes': det_class[:num_det_box]})
if groundtruth_is_crowd is None:
groundtruth_is_crowd = tf.zeros_like(groundtruth_classes, dtype=tf.bool)
if not image_id.shape.as_list():
# Apply a batch dimension to all tensors.
image_id = tf.expand_dims(image_id, 0)
groundtruth_boxes = tf.expand_dims(groundtruth_boxes, 0)
groundtruth_classes = tf.expand_dims(groundtruth_classes, 0)
groundtruth_is_crowd = tf.expand_dims(groundtruth_is_crowd, 0)
detection_boxes = tf.expand_dims(detection_boxes, 0)
detection_scores = tf.expand_dims(detection_scores, 0)
detection_classes = tf.expand_dims(detection_classes, 0)
......@@ -301,6 +312,7 @@ class CocoDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
update_op = tf.py_func(update_op, [image_id,
groundtruth_boxes,
groundtruth_classes,
groundtruth_is_crowd,
num_gt_boxes_per_image,
detection_boxes,
detection_scores,
......@@ -545,7 +557,7 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
groundtruth_classes,
groundtruth_instance_masks,
detection_scores, detection_classes,
detection_masks):
detection_masks, groundtruth_is_crowd=None):
"""Returns a dictionary of eval metric ops to use with `tf.EstimatorSpec`.
Note that once value_op is called, the detections and groundtruth added via
......@@ -568,6 +580,9 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
detection_masks: uint8 tensor array of shape
[num_boxes, image_height, image_width] containing instance masks
corresponding to the boxes. The elements of the array must be in {0, 1}.
groundtruth_is_crowd: bool tensor of shape [batch, num_boxes] containing
is_crowd annotations. This field is optional, and if not passed, then
all boxes are treated as *not* is_crowd.
Returns:
a dictionary of metric names to tuple of value_op and update_op that can
......@@ -580,6 +595,7 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
groundtruth_boxes,
groundtruth_classes,
groundtruth_instance_masks,
groundtruth_is_crowd,
detection_scores,
detection_classes,
detection_masks):
......@@ -587,17 +603,21 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
image_id,
{'groundtruth_boxes': groundtruth_boxes,
'groundtruth_classes': groundtruth_classes,
'groundtruth_instance_masks': groundtruth_instance_masks})
'groundtruth_instance_masks': groundtruth_instance_masks,
'groundtruth_is_crowd': groundtruth_is_crowd})
self.add_single_detected_image_info(
image_id,
{'detection_scores': detection_scores,
'detection_classes': detection_classes,
'detection_masks': detection_masks})
if groundtruth_is_crowd is None:
groundtruth_is_crowd = tf.zeros_like(groundtruth_classes, dtype=tf.bool)
update_op = tf.py_func(update_op, [image_id,
groundtruth_boxes,
groundtruth_classes,
groundtruth_instance_masks,
groundtruth_is_crowd,
detection_scores,
detection_classes,
detection_masks], [])
......
......@@ -492,8 +492,8 @@ class CocoEvaluationPyFuncTest(tf.test.TestCase):
detection_boxes,
detection_scores,
detection_classes,
num_gt_boxes_per_image,
num_det_boxes_per_image)
num_gt_boxes_per_image=num_gt_boxes_per_image,
num_det_boxes_per_image=num_det_boxes_per_image)
_, update_op = eval_metric_ops['DetectionBoxes_Precision/mAP']
......
......@@ -48,8 +48,8 @@ MODEL_BUILD_UTIL_MAP = {
}
def _get_groundtruth_data(detection_model, class_agnostic):
"""Extracts groundtruth data from detection_model.
def _prepare_groundtruth_for_eval(detection_model, class_agnostic):
"""Extracts groundtruth data from detection_model and prepares it for eval.
Args:
detection_model: A `DetectionModel` object.
......@@ -63,6 +63,8 @@ def _get_groundtruth_data(detection_model, class_agnostic):
'groundtruth_classes': [num_boxes] int64 tensor of 1-indexed classes.
'groundtruth_masks': 3D float32 tensor of instance masks (if provided in
groundtruth)
'groundtruth_is_crowd': [num_boxes] bool tensor indicating is_crowd
annotations (if provided in groundtruth).
class_agnostic: Boolean indicating whether detections are class agnostic.
"""
input_data_fields = fields.InputDataFields()
......@@ -86,6 +88,9 @@ def _get_groundtruth_data(detection_model, class_agnostic):
if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
groundtruth[input_data_fields.groundtruth_instance_masks] = (
detection_model.groundtruth_lists(fields.BoxListFields.masks)[0])
if detection_model.groundtruth_has_field(fields.BoxListFields.is_crowd):
groundtruth[input_data_fields.groundtruth_is_crowd] = (
detection_model.groundtruth_lists(fields.BoxListFields.is_crowd)[0])
return groundtruth
......@@ -224,13 +229,16 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
gt_keypoints_list = None
if fields.InputDataFields.groundtruth_keypoints in labels:
gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
gt_is_crowd_list = None
if fields.InputDataFields.groundtruth_is_crowd in labels:
gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
detection_model.provide_groundtruth(
groundtruth_boxes_list=gt_boxes_list,
groundtruth_classes_list=gt_classes_list,
groundtruth_masks_list=gt_masks_list,
groundtruth_keypoints_list=gt_keypoints_list,
groundtruth_weights_list=labels[
fields.InputDataFields.groundtruth_weights])
fields.InputDataFields.groundtruth_weights],
groundtruth_is_crowd_list=gt_is_crowd_list)
preprocessed_images = features[fields.InputDataFields.image]
prediction_dict = detection_model.predict(
......@@ -328,7 +336,8 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
if mode == tf.estimator.ModeKeys.EVAL:
class_agnostic = (fields.DetectionResultFields.detection_classes
not in detections)
groundtruth = _get_groundtruth_data(detection_model, class_agnostic)
groundtruth = _prepare_groundtruth_for_eval(
detection_model, class_agnostic)
use_original_images = fields.InputDataFields.original_image in features
eval_images = (
features[fields.InputDataFields.original_image] if use_original_images
......@@ -339,7 +348,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
detections,
groundtruth,
class_agnostic=class_agnostic,
scale_to_absolute=False)
scale_to_absolute=True)
if class_agnostic:
category_index = label_map_util.create_class_agnostic_category_index()
......@@ -360,8 +369,10 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
if not eval_metrics:
eval_metrics = ['coco_detection_metrics']
eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
eval_metrics, category_index.values(), eval_dict,
include_metrics_per_category=False)
eval_metrics,
category_index.values(),
eval_dict,
include_metrics_per_category=eval_config.include_metrics_per_category)
for loss_key, loss_tensor in iter(losses_dict.items()):
eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
for var in optimizer_summary_vars:
......@@ -528,6 +539,7 @@ def create_train_and_eval_specs(train_input_fn,
train_steps,
eval_steps,
eval_on_train_data=False,
eval_on_train_steps=None,
final_exporter_name='Servo',
eval_spec_name='eval'):
"""Creates a `TrainSpec` and `EvalSpec`s.
......@@ -542,6 +554,8 @@ def create_train_and_eval_specs(train_input_fn,
eval_steps: Number of eval steps.
eval_on_train_data: Whether to evaluate model on training data. Default is
False.
eval_on_train_steps: Number of eval steps for training data. If not given,
uses eval_steps.
final_exporter_name: String name given to `FinalExporter`.
eval_spec_name: String name given to main `EvalSpec`.
......@@ -569,7 +583,7 @@ def create_train_and_eval_specs(train_input_fn,
eval_specs.append(
tf.estimator.EvalSpec(
name='eval_on_train', input_fn=eval_on_train_input_fn,
steps=eval_steps))
steps=eval_on_train_steps or eval_steps))
return train_spec, eval_specs
......
......@@ -253,6 +253,7 @@ class ModelLibTest(tf.test.TestCase):
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
train_steps = 20
eval_steps = 10
eval_on_train_steps = 15
train_and_eval_dict = model_lib.create_estimator_and_inputs(
run_config,
hparams,
......@@ -274,6 +275,7 @@ class ModelLibTest(tf.test.TestCase):
train_steps,
eval_steps,
eval_on_train_data=True,
eval_on_train_steps=eval_on_train_steps,
final_exporter_name='exporter',
eval_spec_name='holdout')
self.assertEqual(train_steps, train_spec.max_steps)
......@@ -281,7 +283,7 @@ class ModelLibTest(tf.test.TestCase):
self.assertEqual(eval_steps, eval_specs[0].steps)
self.assertEqual('holdout', eval_specs[0].name)
self.assertEqual('exporter', eval_specs[0].exporters[0].name)
self.assertEqual(eval_steps, eval_specs[1].steps)
self.assertEqual(eval_on_train_steps, eval_specs[1].steps)
self.assertEqual('eval_on_train', eval_specs[1].name)
def test_experiment(self):
......
......@@ -185,8 +185,9 @@ def fpn_top_down_feature_maps(image_features, depth, scope=None):
See https://arxiv.org/abs/1612.03144 for details.
Args:
image_features: list of image feature tensors. Spatial resolutions of
successive tensors must reduce exactly by a factor of 2.
image_features: list of tuples of (tensor_name, image_feature_tensor).
Spatial resolutions of successive tensors must reduce exactly by a factor
of 2.
depth: depth of output feature maps.
scope: A scope name to wrap this op under.
......@@ -194,32 +195,31 @@ def fpn_top_down_feature_maps(image_features, depth, scope=None):
feature_maps: an OrderedDict mapping keys (feature map names) to
tensors where each tensor has shape [batch, height_i, width_i, depth_i].
"""
with tf.variable_scope(
scope, 'top_down', image_features):
with tf.name_scope(scope, 'top_down'):
num_levels = len(image_features)
output_feature_maps_list = []
output_feature_map_keys = []
with slim.arg_scope(
[slim.conv2d],
activation_fn=None, normalizer_fn=None, padding='SAME', stride=1):
[slim.conv2d], padding='SAME', stride=1):
top_down = slim.conv2d(
image_features[-1],
depth, [1, 1], scope='projection_%d' % num_levels)
image_features[-1][1],
depth, [1, 1], activation_fn=None, normalizer_fn=None,
scope='projection_%d' % num_levels)
output_feature_maps_list.append(top_down)
output_feature_map_keys.append(
'top_down_feature_map_%d' % (num_levels - 1))
'top_down_%s' % image_features[-1][0])
for level in reversed(range(num_levels - 1)):
top_down = ops.nearest_neighbor_upsampling(top_down, 2)
residual = slim.conv2d(
image_features[level], depth, [1, 1],
image_features[level][1], depth, [1, 1],
activation_fn=None, normalizer_fn=None,
scope='projection_%d' % (level + 1))
top_down = 0.5 * top_down + 0.5 * residual
top_down += residual
output_feature_maps_list.append(slim.conv2d(
top_down,
depth, [3, 3],
activation_fn=None,
scope='smoothing_%d' % (level + 1)))
output_feature_map_keys.append('top_down_feature_map_%d' % level)
output_feature_map_keys.append('top_down_%s' % image_features[level][0])
return collections.OrderedDict(
reversed(zip(output_feature_map_keys, output_feature_maps_list)))
......@@ -138,19 +138,19 @@ class FPNFeatureMapGeneratorTest(tf.test.TestCase):
def test_get_expected_feature_map_shapes(self):
image_features = [
tf.random_uniform([4, 8, 8, 256], dtype=tf.float32),
tf.random_uniform([4, 4, 4, 256], dtype=tf.float32),
tf.random_uniform([4, 2, 2, 256], dtype=tf.float32),
tf.random_uniform([4, 1, 1, 256], dtype=tf.float32),
('block2', tf.random_uniform([4, 8, 8, 256], dtype=tf.float32)),
('block3', tf.random_uniform([4, 4, 4, 256], dtype=tf.float32)),
('block4', tf.random_uniform([4, 2, 2, 256], dtype=tf.float32)),
('block5', tf.random_uniform([4, 1, 1, 256], dtype=tf.float32))
]
feature_maps = feature_map_generators.fpn_top_down_feature_maps(
image_features=image_features, depth=128)
expected_feature_map_shapes = {
'top_down_feature_map_0': (4, 8, 8, 128),
'top_down_feature_map_1': (4, 4, 4, 128),
'top_down_feature_map_2': (4, 2, 2, 128),
'top_down_feature_map_3': (4, 1, 1, 128)
'top_down_block2': (4, 8, 8, 128),
'top_down_block3': (4, 4, 4, 128),
'top_down_block4': (4, 2, 2, 128),
'top_down_block5': (4, 1, 1, 128)
}
init_op = tf.global_variables_initializer()
......
......@@ -147,27 +147,30 @@ class _SSDResnetV1FpnFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
output_stride=None,
store_non_strided_activations=True,
scope=scope)
image_features = self._filter_features(image_features)
last_feature_map = image_features['block4']
with tf.variable_scope(self._fpn_scope_name, reuse=self._reuse_weights):
image_features = self._filter_features(image_features)
with slim.arg_scope(self._conv_hyperparams_fn()):
for i in range(5, 7):
last_feature_map = slim.conv2d(
last_feature_map,
num_outputs=256,
kernel_size=[3, 3],
stride=2,
padding='SAME',
scope='block{}'.format(i))
image_features['bottomup_{}'.format(i)] = last_feature_map
feature_maps = feature_map_generators.fpn_top_down_feature_maps(
[
image_features[key] for key in
['block2', 'block3', 'block4', 'bottomup_5', 'bottomup_6']
],
depth=256,
scope='top_down_features')
return feature_maps.values()
with tf.variable_scope(self._fpn_scope_name,
reuse=self._reuse_weights):
fpn_features = feature_map_generators.fpn_top_down_feature_maps(
[(key, image_features[key])
for key in ['block2', 'block3', 'block4']],
depth=256)
last_feature_map = fpn_features['top_down_block4']
coarse_features = {}
for i in range(5, 7):
last_feature_map = slim.conv2d(
last_feature_map,
num_outputs=256,
kernel_size=[3, 3],
stride=2,
padding='SAME',
scope='bottom_up_block{}'.format(i))
coarse_features['bottom_up_block{}'.format(i)] = last_feature_map
return [fpn_features['top_down_block2'],
fpn_features['top_down_block3'],
fpn_features['top_down_block4'],
coarse_features['bottom_up_block5'],
coarse_features['bottom_up_block6']]
class SSDResnet50V1FpnFeatureExtractor(_SSDResnetV1FpnFeatureExtractor):
......
......@@ -72,4 +72,7 @@ message EvalConfig {
// Whether to retain original images (i.e. not pre-processed) in the tensor
// dictionary, so that they can be displayed in Tensorboard.
optional bool retain_original_images = 23 [default=true];
// If True, additionally include per-category metrics.
optional bool include_metrics_per_category = 24 [default=false];
}
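A minimal sketch of turning the new option on in a pipeline config (field name from the message above):

eval_config {
  include_metrics_per_category: true
}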
syntax = "proto2";
package object_detection.protos;
// Message to configure graph rewriter for the tf graph.
message GraphRewriter {
optional Quantization quantization = 1;
}
// Message for quantization options. See
// tensorflow/contrib/quantize/python/quantize.py for details.
message Quantization {
// Number of steps to delay before quantization takes effect during training.
optional int32 delay = 1 [default = 500000];
  // Number of bits to use for quantizing weights.
  // Only 8-bit quantization is supported for now.
  optional int32 weight_bits = 2 [default = 8];
  // Number of bits to use for quantizing activations.
  // Only 8-bit quantization is supported for now.
  optional int32 activation_bits = 3 [default = 8];
}
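A sketch of how this proto is consumed, based on the builder call that appears later in this commit (train.py); the config literal is illustrative:

from google.protobuf import text_format
from object_detection.builders import graph_rewriter_builder
from object_detection.protos import graph_rewriter_pb2

graph_rewriter_config = graph_rewriter_pb2.GraphRewriter()
text_format.Merge("""
  quantization {
    delay: 500000
    weight_bits: 8
    activation_bits: 8
  }
""", graph_rewriter_config)
# Returns a function that rewrites the default graph in place.
graph_rewriter_fn = graph_rewriter_builder.build(
    graph_rewriter_config, is_training=True)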
......@@ -38,6 +38,10 @@ message Hyperparams {
// BatchNorm hyperparameters. If this parameter is NOT set then BatchNorm is
// not applied!
optional BatchNorm batch_norm = 5;
// Whether depthwise convolutions should be regularized. If this parameter is
// NOT set then the conv hyperparams will default to the parent scope.
optional bool regularize_depthwise = 6 [default = false];
}
// Proto with one-of field for regularizers.
......
......@@ -20,6 +20,9 @@ message Loss {
// Localization loss weight.
optional float localization_weight = 5 [default=1.0];
// If not left to default, applies random example sampling.
optional RandomExampleSampler random_example_sampler = 6;
}
// Configuration for bounding box localization loss function.
......@@ -121,7 +124,7 @@ message BootstrappedSigmoidClassificationLoss {
optional bool anchorwise_output = 3 [default=false];
}
// Configuation for hard example miner.
// Configuration for hard example miner.
message HardExampleMiner {
// Maximum number of hard examples to be selected per image (prior to
// enforcing max negative to positive ratio constraint). If set to 0,
......@@ -152,3 +155,10 @@ message HardExampleMiner {
// detection per image.
optional int32 min_negatives_per_image = 5 [default=0];
}
// Configuration for random example sampler.
message RandomExampleSampler {
// The desired fraction of positive samples in batch when applying random
// example sampling.
optional float positive_sample_fraction = 1 [default = 0.01];
}
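A rough sketch of the sampling behavior this configures. It assumes the meta-arch wires the option to the library's BalancedPositiveNegativeSampler (an assumption; that wiring is not shown in this hunk), and the tensors are illustrative:

import tensorflow as tf
from object_detection.core import balanced_positive_negative_sampler as sampler

random_sampler = sampler.BalancedPositiveNegativeSampler(
    positive_fraction=0.25)  # desired fraction of positives in the batch
indicator = tf.ones([64], dtype=tf.bool)           # anchors eligible to sample
labels = tf.greater(tf.random_uniform([64]), 0.9)  # mask of positive anchors
sampled_mask = random_sampler.subsample(indicator, 64, labels)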
......@@ -5,4 +5,6 @@ package object_detection.protos;
// Configuration proto for MeanStddevBoxCoder. See
// box_coders/mean_stddev_box_coder.py for details.
message MeanStddevBoxCoder {
// The standard deviation used to encode and decode boxes.
optional float stddev = 1 [default=0.01];
}
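A minimal usage sketch of the coder this proto configures (the tensor values are illustrative):

import tensorflow as tf
from object_detection.box_coders import mean_stddev_box_coder
from object_detection.core import box_list

coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.01)
boxes = box_list.BoxList(tf.constant([[0.0, 0.0, 0.5, 0.5]]))
anchors = box_list.BoxList(tf.constant([[0.1, 0.1, 0.6, 0.6]]))
# Encodes as (box_corners - anchor_corners) / stddev; decode inverts it.
rel_codes = coder.encode(boxes, anchors)
decoded_boxes = coder.decode(rel_codes, anchors)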
......@@ -3,6 +3,7 @@ syntax = "proto2";
package object_detection.protos;
import "object_detection/protos/eval.proto";
import "object_detection/protos/graph_rewriter.proto";
import "object_detection/protos/input_reader.proto";
import "object_detection/protos/model.proto";
import "object_detection/protos/train.proto";
......@@ -15,5 +16,6 @@ message TrainEvalPipelineConfig {
optional InputReader train_input_reader = 3;
optional EvalConfig eval_config = 4;
optional InputReader eval_input_reader = 5;
optional GraphRewriter graph_rewriter = 6;
extensions 1000 to max;
}
......@@ -53,8 +53,7 @@ model {
num_layers_before_predictor: 0
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 3
use_depthwise: true
kernel_size: 1
box_code_size: 4
apply_sigmoid_to_scores: false
conv_hyperparams {
......@@ -84,7 +83,6 @@ model {
type: 'ssd_mobilenet_v2'
min_depth: 16
depth_multiplier: 1.0
use_depthwise: true
conv_hyperparams {
activation: RELU_6,
regularizer {
......
# SSDLite with Mobilenet v1 configuration for MSCOCO Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
model {
ssd {
num_classes: 90
box_coder {
faster_rcnn_box_coder {
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
}
}
matcher {
argmax_matcher {
matched_threshold: 0.5
unmatched_threshold: 0.5
ignore_thresholds: false
negatives_lower_than_unmatched: true
force_match_for_each_row: true
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
num_layers: 6
min_scale: 0.2
max_scale: 0.95
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
aspect_ratios: 3.0
aspect_ratios: 0.3333
}
}
image_resizer {
fixed_shape_resizer {
height: 300
width: 300
}
}
box_predictor {
convolutional_box_predictor {
min_depth: 0
max_depth: 0
num_layers_before_predictor: 0
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 3
use_depthwise: true
box_code_size: 4
apply_sigmoid_to_scores: false
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
}
feature_extractor {
type: 'ssd_mobilenet_v1'
min_depth: 16
depth_multiplier: 1.0
use_depthwise: true
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
loss {
classification_loss {
weighted_sigmoid {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
hard_example_miner {
num_hard_examples: 3000
iou_threshold: 0.99
loss_type: CLASSIFICATION
max_negatives_per_positive: 3
min_negatives_per_image: 0
}
classification_weight: 1.0
localization_weight: 1.0
}
normalize_loss_by_num_matches: true
post_processing {
batch_non_max_suppression {
score_threshold: 1e-8
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 100
}
score_converter: SIGMOID
}
}
}
train_config: {
batch_size: 24
optimizer {
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 800720
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
}
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
from_detection_checkpoint: true
  # Note: The below line limits the training process to 200K steps, which we
  # found empirically to be sufficient. This effectively bypasses the learning
  # rate schedule (the learning rate will never decay). Remove the below line
  # to train indefinitely.
num_steps: 200000
data_augmentation_options {
random_horizontal_flip {
}
}
data_augmentation_options {
ssd_random_crop {
}
}
}
train_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
eval_config: {
num_examples: 8000
# Note: The below line limits the evaluation process to 10 evaluations.
# Remove the below line to evaluate indefinitely.
max_evals: 10
}
eval_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
shuffle: false
num_readers: 1
}
# SSDLite with Mobilenet v2 configuration for MSCOCO Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
model {
ssd {
num_classes: 90
box_coder {
faster_rcnn_box_coder {
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
}
}
matcher {
argmax_matcher {
matched_threshold: 0.5
unmatched_threshold: 0.5
ignore_thresholds: false
negatives_lower_than_unmatched: true
force_match_for_each_row: true
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
num_layers: 6
min_scale: 0.2
max_scale: 0.95
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
aspect_ratios: 3.0
aspect_ratios: 0.3333
}
}
image_resizer {
fixed_shape_resizer {
height: 300
width: 300
}
}
box_predictor {
convolutional_box_predictor {
min_depth: 0
max_depth: 0
num_layers_before_predictor: 0
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 3
use_depthwise: true
box_code_size: 4
apply_sigmoid_to_scores: false
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
}
feature_extractor {
type: 'ssd_mobilenet_v2'
min_depth: 16
depth_multiplier: 1.0
use_depthwise: true
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
loss {
classification_loss {
weighted_sigmoid {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
hard_example_miner {
num_hard_examples: 3000
iou_threshold: 0.99
loss_type: CLASSIFICATION
max_negatives_per_positive: 3
min_negatives_per_image: 3
}
classification_weight: 1.0
localization_weight: 1.0
}
normalize_loss_by_num_matches: true
post_processing {
batch_non_max_suppression {
score_threshold: 1e-8
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 100
}
score_converter: SIGMOID
}
}
}
train_config: {
batch_size: 24
optimizer {
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 800720
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
}
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
fine_tune_checkpoint_type: "detection"
  # Note: The below line limits the training process to 200K steps, which we
  # found empirically to be sufficient. This effectively bypasses the learning
  # rate schedule (the learning rate will never decay). Remove the below line
  # to train indefinitely.
num_steps: 200000
data_augmentation_options {
random_horizontal_flip {
}
}
data_augmentation_options {
ssd_random_crop {
}
}
}
train_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
eval_config: {
num_examples: 8000
# Note: The below line limits the evaluation process to 10 evaluations.
# Remove the below line to evaluate indefinitely.
max_evals: 10
}
eval_input_reader: {
tf_record_input_reader {
input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"
}
label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
shuffle: false
num_readers: 1
}
\ No newline at end of file
......@@ -48,6 +48,7 @@ import tensorflow as tf
from object_detection import trainer
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import dataset_util
......@@ -158,9 +159,25 @@ def main(_):
is_chief = (task_info.type == 'master')
master = server.target
trainer.train(create_input_dict_fn, model_fn, train_config, master, task,
FLAGS.num_clones, worker_replicas, FLAGS.clone_on_cpu, ps_tasks,
worker_job_name, is_chief, FLAGS.train_dir)
graph_rewriter_fn = None
if 'graph_rewriter_config' in configs:
graph_rewriter_fn = graph_rewriter_builder.build(
configs['graph_rewriter_config'], is_training=True)
trainer.train(
create_input_dict_fn,
model_fn,
train_config,
master,
task,
FLAGS.num_clones,
worker_replicas,
FLAGS.clone_on_cpu,
ps_tasks,
worker_job_name,
is_chief,
FLAGS.train_dir,
graph_hook_fn=graph_rewriter_fn)
if __name__ == '__main__':
......
......@@ -231,10 +231,10 @@ def train(create_tensor_dict_fn,
worker_job_name: Name of the worker job.
is_chief: Whether this replica is the chief replica.
train_dir: Directory to write checkpoints and training summaries to.
graph_hook_fn: Optional function that is called after the training graph is
completely built. This is helpful to perform additional changes to the
training graph such as optimizing batchnorm. The function should modify
the default graph.
graph_hook_fn: Optional function that is called after the inference graph is
built (before optimization). This is helpful to perform additional changes
to the training graph such as adding FakeQuant ops. The function should
modify the default graph.
"""
detection_model = create_model_fn()
......@@ -275,6 +275,10 @@ def train(create_tensor_dict_fn,
clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
first_clone_scope = clones[0].scope
if graph_hook_fn:
with tf.device(deploy_config.variables_device()):
graph_hook_fn()
# Gather update_ops from the first clone. These contain, for example,
# the updates for the batch_norm variables created by model_fn.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
......@@ -328,10 +332,6 @@ def train(create_tensor_dict_fn,
with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss, name='train_op')
if graph_hook_fn:
with tf.device(deploy_config.variables_device()):
graph_hook_fn()
# Add summaries.
for model_var in slim.get_model_variables():
global_summaries.add(tf.summary.histogram('ModelVars/' +
......
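For concreteness, a graph_hook_fn along these lines would add FakeQuant ops; this is a sketch assuming tf.contrib.quantize is used directly (the builder added in this commit wraps an equivalent rewrite):

import tensorflow as tf

def quantization_graph_hook():
  # Rewrites the default training graph in place, inserting FakeQuant ops.
  # Called after the inference graph is built, before the optimizer is added.
  tf.contrib.quantize.create_training_graph(
      input_graph=tf.get_default_graph(), quant_delay=500000)

# trainer.train(..., graph_hook_fn=quantization_graph_hook)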
......@@ -22,6 +22,7 @@ from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from object_detection.protos import eval_pb2
from object_detection.protos import graph_rewriter_pb2
from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2
from object_detection.protos import pipeline_pb2
......@@ -111,9 +112,27 @@ def create_configs_from_pipeline_proto(pipeline_config):
configs["train_input_config"] = pipeline_config.train_input_reader
configs["eval_config"] = pipeline_config.eval_config
configs["eval_input_config"] = pipeline_config.eval_input_reader
if pipeline_config.HasField("graph_rewriter"):
configs["graph_rewriter_config"] = pipeline_config.graph_rewriter
return configs
def get_graph_rewriter_config_from_file(graph_rewriter_config_file):
"""Parses config for graph rewriter.
Args:
graph_rewriter_config_file: file path to the graph rewriter config.
Returns:
graph_rewriter_pb2.GraphRewriter proto
"""
graph_rewriter_config = graph_rewriter_pb2.GraphRewriter()
with tf.gfile.GFile(graph_rewriter_config_file, "r") as f:
text_format.Merge(f.read(), graph_rewriter_config)
return graph_rewriter_config
def create_pipeline_proto_from_configs(configs):
"""Creates a pipeline_pb2.TrainEvalPipelineConfig from configs dictionary.
......@@ -132,6 +151,8 @@ def create_pipeline_proto_from_configs(configs):
pipeline_config.train_input_reader.CopyFrom(configs["train_input_config"])
pipeline_config.eval_config.CopyFrom(configs["eval_config"])
pipeline_config.eval_input_reader.CopyFrom(configs["eval_input_config"])
if "graph_rewriter_config" in configs:
pipeline_config.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
return pipeline_config
......@@ -157,7 +178,8 @@ def get_configs_from_multiple_files(model_config_path="",
train_config_path="",
train_input_config_path="",
eval_config_path="",
eval_input_config_path=""):
eval_input_config_path="",
graph_rewriter_config_path=""):
"""Reads training configuration from multiple config files.
Args:
......@@ -166,6 +188,7 @@ def get_configs_from_multiple_files(model_config_path="",
train_input_config_path: Path to input_reader_pb2.InputReader.
eval_config_path: Path to eval_pb2.EvalConfig.
eval_input_config_path: Path to input_reader_pb2.InputReader.
graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
......@@ -203,6 +226,10 @@ def get_configs_from_multiple_files(model_config_path="",
text_format.Merge(f.read(), eval_input_config)
configs["eval_input_config"] = eval_input_config
if graph_rewriter_config_path:
configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
graph_rewriter_config_path)
return configs
......
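Usage sketch for the extended loader; the file paths are placeholders:

from object_detection.utils import config_util

configs = config_util.get_configs_from_multiple_files(
    model_config_path='model.config',
    train_config_path='train.config',
    train_input_config_path='train_input.config',
    eval_config_path='eval.config',
    eval_input_config_path='eval_input.config',
    graph_rewriter_config_path='graph_rewriter.config')
# The 'graph_rewriter_config' key is only present when a path was supplied.
use_rewriter = 'graph_rewriter_config' in configs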
......@@ -132,7 +132,7 @@ def read_dataset(file_read_func, decode_func, input_files, config):
records_dataset = filename_dataset.apply(
tf.contrib.data.parallel_interleave(
file_read_func, cycle_length=config.num_readers,
block_length=config.read_block_length, sloppy=True))
block_length=config.read_block_length, sloppy=config.shuffle))
if config.shuffle:
records_dataset.shuffle(config.shuffle_buffer_size)
tensor_dataset = records_dataset.map(
......
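The sloppy flag lets parallel_interleave emit elements in whatever order the parallel readers produce them, trading determinism for throughput; tying it to config.shuffle keeps evaluation runs deterministic while training, which shuffles anyway, stays fast. A standalone sketch with illustrative file names:

import tensorflow as tf

filenames = tf.data.Dataset.from_tensor_slices(
    ['shard-0.tfrecord', 'shard-1.tfrecord'])
# sloppy=False preserves a deterministic element order (evaluation);
# sloppy=True allows non-deterministic, faster interleaving (training).
records = filenames.apply(
    tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=2, sloppy=False))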