提交 f68a262d 编写于 作者: A A. Unique TensorFlower 提交者: TF Object Detection Team

CenterNet + track embedding joint training pipeline.

PiperOrigin-RevId: 325470519
上级 9cbce60e
......@@ -59,7 +59,8 @@ def build(input_reader_config):
num_additional_channels=input_reader_config.num_additional_channels,
num_keypoints=input_reader_config.num_keypoints,
expand_hierarchy_labels=input_reader_config.expand_labels_hierarchy,
load_dense_pose=input_reader_config.load_dense_pose)
load_dense_pose=input_reader_config.load_dense_pose,
load_track_id=input_reader_config.load_track_id)
return decoder
elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
......
......@@ -918,6 +918,24 @@ def densepose_proto_to_params(densepose_config):
heatmap_bias_init=densepose_config.heatmap_bias_init)
def tracking_proto_to_params(tracking_config):
  """Builds a TrackParams namedtuple from a CenterNet.TrackEstimation proto.

  Args:
    tracking_config: The track estimation proto. Its `classification_loss`,
      `num_track_ids`, `reid_embed_size`, `num_fc_layers` and
      `task_loss_weight` fields are read.

  Returns:
    A center_net_meta_arch.TrackParams whose `classification_loss` is the loss
    object built by `losses_builder.build`; the other fields are copied from
    the proto unchanged.
  """
  loss_proto = losses_pb2.Loss()
  loss_proto.classification_loss.CopyFrom(tracking_config.classification_loss)
  # The loss builder errors out without a localization loss, so a placeholder
  # weighted-L2 loss is attached here; its built counterpart is discarded.
  # TODO(yuhuic): update the loss builder to take the localization loss
  # directly.
  loss_proto.localization_loss.weighted_l2.CopyFrom(
      losses_pb2.WeightedL2LocalizationLoss())

  # Only the first of the seven values returned by the builder is needed.
  (reid_classification_loss, _, _, _, _, _, _) = losses_builder.build(
      loss_proto)

  return center_net_meta_arch.TrackParams(
      num_track_ids=tracking_config.num_track_ids,
      reid_embed_size=tracking_config.reid_embed_size,
      num_fc_layers=tracking_config.num_fc_layers,
      classification_loss=reid_classification_loss,
      task_loss_weight=tracking_config.task_loss_weight)
def _build_center_net_model(center_net_config, is_training, add_summaries):
"""Build a CenterNet detection model.
......@@ -975,6 +993,11 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
densepose_params = densepose_proto_to_params(
center_net_config.densepose_estimation_task)
track_params = None
if center_net_config.HasField('track_estimation_task'):
track_params = tracking_proto_to_params(
center_net_config.track_estimation_task)
return center_net_meta_arch.CenterNetMetaArch(
is_training=is_training,
add_summaries=add_summaries,
......@@ -985,7 +1008,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
object_detection_params=object_detection_params,
keypoint_params_dict=keypoint_params_dict,
mask_params=mask_params,
densepose_params=densepose_params)
densepose_params=densepose_params,
track_params=track_params)
def _build_center_net_feature_extractor(
......
......@@ -102,7 +102,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
Args:
field: a string key, options are
fields.BoxListFields.{boxes,classes,masks,keypoints,
keypoint_visibilities, densepose_*}
keypoint_visibilities, densepose_*, track_ids}
fields.InputDataFields.is_annotated.
Returns:
......@@ -123,7 +123,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
Args:
field: a string key, options are
fields.BoxListFields.{boxes,classes,masks,keypoints,
keypoint_visibilities, densepose_*} or
keypoint_visibilities, densepose_*, track_ids} or
fields.InputDataFields.is_annotated.
Returns:
......@@ -303,6 +303,7 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
groundtruth_dp_num_points_list=None,
groundtruth_dp_part_ids_list=None,
groundtruth_dp_surface_coords_list=None,
groundtruth_track_ids_list=None,
groundtruth_weights_list=None,
groundtruth_confidences_list=None,
groundtruth_is_crowd_list=None,
......@@ -342,6 +343,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
shape [num_boxes, max_sampled_points, 4] containing the DensePose
surface coordinates for each sampled point. Note that there may be
padding.
groundtruth_track_ids_list: a list of 1-D tf.int32 tensors of shape
[num_boxes] containing the track IDs of groundtruth objects.
groundtruth_weights_list: A list of 1-D tf.float32 tensors of shape
[num_boxes] containing weights for groundtruth boxes.
groundtruth_confidences_list: A list of 2-D tf.float32 tensors of shape
......@@ -391,6 +394,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
self._groundtruth_lists[
fields.BoxListFields.densepose_surface_coords] = (
groundtruth_dp_surface_coords_list)
if groundtruth_track_ids_list:
self._groundtruth_lists[
fields.BoxListFields.track_ids] = groundtruth_track_ids_list
if groundtruth_is_crowd_list:
self._groundtruth_lists[
fields.BoxListFields.is_crowd] = groundtruth_is_crowd_list
......
......@@ -46,6 +46,7 @@ class InputDataFields(object):
classes for which an image has been labeled.
groundtruth_boxes: coordinates of the ground truth boxes in the image.
groundtruth_classes: box-level class labels.
groundtruth_track_ids: box-level track ID labels.
groundtruth_confidences: box-level class confidences. The shape should be
the same as the shape of groundtruth_classes.
groundtruth_label_types: box-level label types (e.g. explicit negative).
......@@ -97,6 +98,7 @@ class InputDataFields(object):
groundtruth_labeled_classes = 'groundtruth_labeled_classes'
groundtruth_boxes = 'groundtruth_boxes'
groundtruth_classes = 'groundtruth_classes'
groundtruth_track_ids = 'groundtruth_track_ids'
groundtruth_confidences = 'groundtruth_confidences'
groundtruth_label_types = 'groundtruth_label_types'
groundtruth_is_crowd = 'groundtruth_is_crowd'
......@@ -167,6 +169,7 @@ class DetectionResultFields(object):
detection_boundaries = 'detection_boundaries'
detection_keypoints = 'detection_keypoints'
detection_keypoint_scores = 'detection_keypoint_scores'
detection_embeddings = 'detection_embeddings'
num_detections = 'num_detections'
raw_detection_boxes = 'raw_detection_boxes'
raw_detection_scores = 'raw_detection_scores'
......@@ -208,6 +211,7 @@ class BoxListFields(object):
densepose_surface_coords = 'densepose_surface_coords'
is_crowd = 'is_crowd'
group_of = 'group_of'
track_ids = 'track_ids'
class PredictionFields(object):
......
......@@ -1782,6 +1782,89 @@ class CenterNetDensePoseTargetAssigner(object):
return batch_indices, batch_part_ids, batch_surface_coords, batch_weights
class CenterNetTrackTargetAssigner(object):
  """Produces training targets for the track ID classification task.

  Reference paper: A Simple Baseline for Multi-Object Tracking [1]

  [1]: https://arxiv.org/abs/2004.01888
  """

  def __init__(self, stride, num_track_ids):
    """Stores the assigner configuration.

    Args:
      stride: int, the factor by which the input height/width are divided to
        obtain the output feature map size.
      num_track_ids: int, the depth of the one-hot track ID targets.
    """
    self._stride = stride
    self._num_track_ids = num_track_ids

  def assign_track_targets(self,
                           height,
                           width,
                           gt_track_ids_list,
                           gt_boxes_list,
                           gt_weights_list=None):
    """Computes the track ID targets.

    Args:
      height: int, height of input to the model. This is used to determine the
        height of the output.
      width: int, width of the input to the model. This is used to determine
        the width of the output.
      gt_track_ids_list: A list of 1-D tensors with shape [num_boxes]
        corresponding to the track ID of each groundtruth detection box.
      gt_boxes_list: A list of float tensors with shape [num_boxes, 4]
        representing the groundtruth detection bounding boxes for each sample
        in the batch. The coordinates are expected in normalized coordinates.
      gt_weights_list: A list of 1-D tensors with shape [num_boxes]
        corresponding to the weight of each groundtruth detection box. When
        None, every box gets a weight of one.

    Returns:
      batch_indices: an integer tensor of shape [batch_size, num_boxes, 3]
        holding the indices inside the predicted tensor which should be
        penalized. The first column indicates the index along the batch
        dimension and the second and third columns indicate the index
        along the y and x dimensions respectively.
      batch_weights: a float tensor of shape [batch_size, num_boxes] indicating
        the weight of each prediction.
      track_id_targets: A tensor of size [batch_size, num_boxes,
        num_track_ids] containing the one-hot track ID vector of each
        groundtruth detection box. It is produced by tf.one_hot without an
        explicit dtype.
    """
    if gt_weights_list is None:
      gt_weights_list = [None] * len(gt_boxes_list)

    per_image_indices = []
    per_image_weights = []
    for image_index, (normalized_boxes, box_weights) in enumerate(
        zip(gt_boxes_list, gt_weights_list)):
      # Rescale the normalized boxes to output feature map coordinates.
      absolute_boxes = box_list_ops.to_absolute_coordinates(
          box_list.BoxList(normalized_boxes),
          height // self._stride,
          width // self._stride)

      # Box centers; each returned tensor has shape [num_boxes].
      (center_y, center_x, _, _) = (
          absolute_boxes.get_center_coordinates_and_sizes())

      # Indices of the box centers, shape [num_boxes, 2]. The offsets
      # returned alongside them are not needed here.
      (_, center_indices) = ta_utils.compute_floor_offsets_with_indices(
          y_source=center_y, x_source=center_x)

      # Default to unit weights when the caller supplied none.
      if box_weights is None:
        box_weights = tf.ones(tf.shape(center_x), dtype=tf.float32)

      # [num_boxes, 1] integer column filled with the current batch index.
      batch_column = image_index * tf.ones_like(
          center_indices[:, 0:1], dtype=tf.int32)
      per_image_indices.append(
          tf.concat([batch_column, center_indices], axis=1))
      per_image_weights.append(box_weights)

    batch_indices = tf.stack(per_image_indices, axis=0)
    batch_weights = tf.stack(per_image_weights, axis=0)
    track_id_targets = tf.one_hot(
        gt_track_ids_list, depth=self._num_track_ids, axis=-1)

    return batch_indices, batch_weights, track_id_targets
def filter_mask_overlap_min_area(masks):
"""If a pixel belongs to 2 instances, remove it from the larger instance."""
......
......@@ -2015,6 +2015,106 @@ class CenterNetDensePoseTargetAssignerTest(test_case.TestCase):
self.assertAllClose(expected_batch_weights, batch_weights)
class CenterNetTrackTargetAssignerTest(test_case.TestCase):
  """Checks the targets produced by CenterNetTrackTargetAssigner."""

  def setUp(self):
    super(CenterNetTrackTargetAssignerTest, self).setUp()
    # Normalized [ymin, xmin, ymax, xmax] boxes shared by the tests.
    self._box_center = [0.0, 0.0, 1.0, 1.0]
    self._box_center_small = [0.25, 0.25, 0.75, 0.75]
    self._box_lower_left = [0.5, 0.0, 1.0, 0.5]
    self._box_center_offset = [0.1, 0.05, 1.0, 1.0]
    self._box_odd_coordinates = [0.1625, 0.2125, 0.5625, 0.9625]

  def _run_assigner(self, box_weights=None):
    """Runs assign_track_targets on a fixed batch of three images.

    Args:
      box_weights: Optional list with one list of per-box weights per image.
        When None, no `gt_weights_list` argument is passed to the assigner.

    Returns:
      The evaluated (batch_indices, batch_weights, track_targets) outputs.
    """

    def graph_fn():
      box_batch = [
          tf.constant([self._box_center, self._box_lower_left]),
          tf.constant([self._box_lower_left, self._box_center_small]),
          tf.constant([self._box_center_small, self._box_odd_coordinates]),
      ]
      track_id_batch = [
          tf.constant(track_ids) for track_ids in ([0, 1], [1, 0], [0, 2])
      ]
      assigner = targetassigner.CenterNetTrackTargetAssigner(
          stride=4, num_track_ids=3)
      assign_kwargs = dict(
          height=80,
          width=80,
          gt_track_ids_list=track_id_batch,
          gt_boxes_list=box_batch)
      if box_weights is not None:
        assign_kwargs['gt_weights_list'] = [
            tf.constant(image_weights) for image_weights in box_weights
        ]
      (batch_indices, batch_weights,
       track_targets) = assigner.assign_track_targets(**assign_kwargs)
      return batch_indices, batch_weights, track_targets

    return self.execute(graph_fn, [])

  def _check_shapes_indices_and_targets(self, indices, weights, track_ids):
    """Asserts the output shapes plus the expected indices and one-hot IDs."""
    self.assertEqual(indices.shape, (3, 2, 3))
    self.assertEqual(track_ids.shape, (3, 2, 3))
    self.assertEqual(weights.shape, (3, 2))

    expected_indices = [[[0, 10, 10], [0, 15, 5]],
                        [[1, 15, 5], [1, 10, 10]],
                        [[2, 10, 10], [2, 7, 11]]]
    expected_track_ids = [[[1, 0, 0], [0, 1, 0]],
                          [[0, 1, 0], [1, 0, 0]],
                          [[1, 0, 0], [0, 0, 1]]]
    np.testing.assert_array_equal(indices, expected_indices)
    np.testing.assert_array_equal(track_ids, expected_track_ids)

  def test_assign_track_targets(self):
    """Test the assign_track_targets function."""
    indices, weights, track_ids = self._run_assigner()

    self._check_shapes_indices_and_targets(indices, weights, track_ids)
    # Without explicit weights every box is weighted by one.
    np.testing.assert_array_equal(weights, [[1, 1], [1, 1], [1, 1]])

  def test_assign_track_targets_weights(self):
    """Test the assign_track_targets function with box weights."""
    indices, weights, track_ids = self._run_assigner(
        box_weights=[[0.0, 1.0], [1.0, 1.0], [0.0, 0.0]])

    self._check_shapes_indices_and_targets(indices, weights, track_ids)
    # The supplied weights are passed through unchanged.
    np.testing.assert_array_equal(weights, [[0, 1], [1, 1], [0, 0]])

  # TODO(xwwang): Add a test for the case when no objects are detected.
class CornerOffsetTargetAssignerTest(test_case.TestCase):
def test_filter_overlap_min_area_empty(self):
......
......@@ -138,7 +138,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
load_multiclass_scores=False,
load_context_features=False,
expand_hierarchy_labels=False,
load_dense_pose=False):
load_dense_pose=False,
load_track_id=False):
"""Constructor sets keys_to_features and items_to_handlers.
Args:
......@@ -170,6 +171,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
classes, the labels are extended to ancestor. For negative classes,
the labels are expanded to descendants.
load_dense_pose: Whether to load DensePose annotations.
load_track_id: Whether to load tracking annotations.
Raises:
ValueError: If `instance_mask_type` option is not one of
......@@ -367,6 +369,12 @@ class TfExampleDecoder(data_decoder.DataDecoder):
'image/object/densepose/u', 'image/object/densepose/v',
'image/object/densepose/num'],
self._dense_pose_surface_coordinates))
if load_track_id:
self.keys_to_features['image/object/track/label'] = (
tf.VarLenFeature(tf.int64))
self.items_to_handlers[
fields.InputDataFields.groundtruth_track_ids] = (
slim_example_decoder.Tensor('image/object/track/label'))
if label_map_proto_file:
# If the label_map_proto is provided, try to use it in conjunction with
......@@ -552,6 +560,11 @@ class TfExampleDecoder(data_decoder.DataDecoder):
tensor_dict[fields.InputDataFields.groundtruth_dp_part_ids],
dtype=tf.int32)
if fields.InputDataFields.groundtruth_track_ids in tensor_dict:
tensor_dict[fields.InputDataFields.groundtruth_track_ids] = tf.cast(
tensor_dict[fields.InputDataFields.groundtruth_track_ids],
dtype=tf.int32)
return tensor_dict
def _reshape_keypoints(self, keys_to_tensors):
......
......@@ -1430,6 +1430,48 @@ class TfExampleDecoderTest(test_case.TestCase):
self.assertAllEqual(dp_part_ids, expected_dp_part_ids)
self.assertAllClose(dp_surface_coords, expected_dp_surface_coords)
def testDecodeTrack(self):
  """Decodes an example with track labels and checks the decoded track IDs."""
  image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
  encoded_jpeg, _ = self._create_encoded_and_decoded_data(
      image_tensor, 'jpeg')
  bbox_ymins = [0.0, 4.0, 2.0]
  bbox_xmins = [1.0, 5.0, 8.0]
  bbox_ymaxs = [2.0, 6.0, 1.0]
  bbox_xmaxs = [3.0, 7.0, 3.3]
  track_labels = [0, 1, 2]

  def graph_fn():
    feature_map = {
        'image/encoded': dataset_util.bytes_feature(encoded_jpeg),
        'image/format': dataset_util.bytes_feature(six.b('jpeg')),
        'image/object/bbox/ymin':
            dataset_util.float_list_feature(bbox_ymins),
        'image/object/bbox/xmin':
            dataset_util.float_list_feature(bbox_xmins),
        'image/object/bbox/ymax':
            dataset_util.float_list_feature(bbox_ymaxs),
        'image/object/bbox/xmax':
            dataset_util.float_list_feature(bbox_xmaxs),
        'image/object/track/label':
            dataset_util.int64_list_feature(track_labels),
    }
    serialized_example = tf.train.Example(
        features=tf.train.Features(feature=feature_map)).SerializeToString()

    # Track IDs are only decoded when load_track_id is switched on.
    decoder = tf_example_decoder.TfExampleDecoder(load_track_id=True)
    decoded = decoder.decode(tf.convert_to_tensor(serialized_example))
    return decoded[fields.InputDataFields.groundtruth_track_ids]

  decoded_track_ids = self.execute_cpu(graph_fn, [])

  self.assertAllEqual(decoded_track_ids, [0, 1, 2])
if __name__ == '__main__':
tf.test.main()
......@@ -499,6 +499,9 @@ def pad_input_data_to_static_shapes(tensor_dict,
padding_shapes[
fields.InputDataFields.groundtruth_dp_surface_coords] = [
max_num_boxes, max_dp_points, 4]
if fields.InputDataFields.groundtruth_track_ids in tensor_dict:
padding_shapes[
fields.InputDataFields.groundtruth_track_ids] = [max_num_boxes]
# Prepare for ContextRCNN related fields.
if fields.InputDataFields.context_features in tensor_dict:
......@@ -602,7 +605,8 @@ def _get_labels_dict(input_dict):
fields.InputDataFields.groundtruth_keypoint_weights,
fields.InputDataFields.groundtruth_dp_num_points,
fields.InputDataFields.groundtruth_dp_part_ids,
fields.InputDataFields.groundtruth_dp_surface_coords
fields.InputDataFields.groundtruth_dp_surface_coords,
fields.InputDataFields.groundtruth_track_ids
]
for key in optional_label_keys:
......@@ -762,6 +766,8 @@ def train_input(train_config, train_input_config,
DensePose surface coordinates. The format is (y, x, v, u), where (y, x)
are normalized image coordinates and (v, u) are normalized surface part
coordinates.
labels[fields.InputDataFields.groundtruth_track_ids] is a
[batch_size, num_boxes] int32 tensor with the track ID for each object.
Raises:
TypeError: if the `train_config`, `train_input_config` or `model_config`
......@@ -914,6 +920,8 @@ def eval_input(eval_config, eval_input_config, model_config,
DensePose surface coordinates. The format is (y, x, v, u), where (y, x)
are normalized image coordinates and (v, u) are normalized surface part
coordinates.
labels[fields.InputDataFields.groundtruth_track_ids] is a
[batch_size, num_boxes] int32 tensor with the track ID for each object.
Raises:
TypeError: if the `eval_config`, `eval_input_config` or `model_config`
......
......@@ -1569,6 +1569,22 @@ class PadInputDataToStaticShapesFnTest(test_case.TestCase):
padded_tensor_dict[fields.InputDataFields.groundtruth_dp_surface_coords]
.shape.as_list(), [3, 200, 4])
def test_pad_input_data_to_static_shapes_for_trackid(self):
  """Track IDs for two boxes are padded to a static [max_num_boxes] shape."""
  track_ids_key = fields.InputDataFields.groundtruth_track_ids
  unpadded = {track_ids_key: tf.constant([0, 1], dtype=tf.int32)}

  padded = inputs.pad_input_data_to_static_shapes(
      tensor_dict=unpadded,
      max_num_boxes=3,
      num_classes=1,
      spatial_image_shape=[128, 128])

  padded_shape = padded[track_ids_key].shape.as_list()
  self.assertAllEqual(padded_shape, [3])
def test_context_features(self):
context_memory_size = 8
context_feature_length = 10
......
......@@ -1164,6 +1164,33 @@ def gather_surface_coords_for_parts(surface_coords_cropped,
return tf.reshape(vu_coords_flattened, [max_detections, height, width, 2])
def predicted_embeddings_at_object_centers(embedding_predictions,
                                           y_indices, x_indices):
  """Gathers the predicted embeddings located at the given object centers.

  Args:
    embedding_predictions: A float tensor of shape [batch_size, height, width,
      reid_embed_size] holding predicted embeddings.
    y_indices: A [batch, num_instances] int tensor holding y indices for object
      centers. These indices correspond to locations in the output feature map.
    x_indices: A [batch, num_instances] int tensor holding x indices for object
      centers. These indices correspond to locations in the output feature map.

  Returns:
    A float tensor of shape [batch_size, num_instances, reid_embed_size] where
    predicted embeddings are gathered at the provided locations.
  """
  batch_size, _, width, _ = _get_shape(embedding_predictions, 4)

  # Collapse each (y, x) pair into a single index over the spatial dimensions
  # so one batched gather can pick the embeddings.
  flat_center_indices = flattened_indices_from_row_col_indices(
      y_indices, x_indices, width)
  flat_embeddings = _flatten_spatial_dimensions(embedding_predictions)
  _, num_instances = _get_shape(flat_center_indices, 2)

  gathered = tf.gather(flat_embeddings, flat_center_indices, batch_dims=1)
  return tf.reshape(gathered, [batch_size, num_instances, -1])
class ObjectDetectionParams(
collections.namedtuple('ObjectDetectionParams', [
'localization_loss', 'scale_loss_weight', 'offset_loss_weight',
......@@ -1464,6 +1491,42 @@ class DensePoseParams(
task_loss_weight, upsample_to_input_res,
upsample_method, heatmap_bias_init)
class TrackParams(
    collections.namedtuple('TrackParams', [
        'num_track_ids',
        'reid_embed_size',
        'num_fc_layers',
        'classification_loss',
        'task_loss_weight',
    ])):
  """Namedtuple holding the hyper-parameters of the tracking task."""

  __slots__ = ()

  def __new__(cls,
              num_track_ids,
              reid_embed_size,
              num_fc_layers,
              classification_loss,
              task_loss_weight=1.0):
    """Creates a TrackParams, defaulting `task_loss_weight` to 1.0.

    Args:
      num_track_ids: int. The number of track ID classes used by the ReID
        embedding classification task.
      reid_embed_size: int. The embedding size for ReID task.
      num_fc_layers: int. The number of fully-connected layers for the track
        ID classification head.
      classification_loss: an object_detection.core.losses.Loss object to
        compute the loss for the ReID embedding in CenterNet.
      task_loss_weight: float, the loss weight for the tracking task.

    Returns:
      An initialized TrackParams namedtuple.
    """
    return super(TrackParams, cls).__new__(
        cls,
        num_track_ids=num_track_ids,
        reid_embed_size=reid_embed_size,
        num_fc_layers=num_fc_layers,
        classification_loss=classification_loss,
        task_loss_weight=task_loss_weight)
# The following constants are used to generate the keys of the
# (prediction, loss, target assigner,...) dictionaries used in CenterNetMetaArch
# class.
......@@ -1480,6 +1543,8 @@ DENSEPOSE_TASK = 'densepose_task'
DENSEPOSE_HEATMAP = 'densepose/heatmap'
DENSEPOSE_REGRESSION = 'densepose/regression'
LOSS_KEY_PREFIX = 'Loss'
TRACK_TASK = 'track_task'
TRACK_REID = 'track/reid'
def get_keypoint_name(task_name, head_name):
......@@ -1523,7 +1588,8 @@ class CenterNetMetaArch(model.DetectionModel):
object_detection_params=None,
keypoint_params_dict=None,
mask_params=None,
densepose_params=None):
densepose_params=None,
track_params=None):
"""Initializes a CenterNet model.
Args:
......@@ -1555,6 +1621,9 @@ class CenterNetMetaArch(model.DetectionModel):
hyper-parameters for DensePose prediction. Please see the class
definition for more details. Note that if this is provided, it is
expected that `mask_params` is also provided.
track_params: A TrackParams namedtuple. This object
holds the hyper-parameters for tracking. Please see the class
definition for more details.
"""
assert object_detection_params or keypoint_params_dict
# Shorten the name for convenience and better formatting.
......@@ -1574,6 +1643,7 @@ class CenterNetMetaArch(model.DetectionModel):
raise ValueError('To run DensePose prediction, `mask_params` must also '
'be supplied.')
self._densepose_params = densepose_params
self._track_params = track_params
# Construct the prediction head nets.
self._prediction_head_dict = self._construct_prediction_heads(
......@@ -1613,7 +1683,8 @@ class CenterNetMetaArch(model.DetectionModel):
Returns:
A dictionary of keras modules generated by calling make_prediction_net
function.
function. It will also create and set a private member of the class when
learning the tracking task.
"""
prediction_heads = {}
prediction_heads[OBJECT_CENTER] = [
......@@ -1666,6 +1737,26 @@ class CenterNetMetaArch(model.DetectionModel):
make_prediction_net(2 * self._densepose_params.num_parts)
for _ in range(num_feature_outputs)
]
if self._track_params is not None:
prediction_heads[TRACK_REID] = [
make_prediction_net(self._track_params.reid_embed_size)
for _ in range(num_feature_outputs)]
# Creates a classification network to train object embeddings by learning
# a projection from embedding space to object track ID space.
self.track_reid_classification_net = tf.keras.Sequential()
for _ in range(self._track_params.num_fc_layers - 1):
self.track_reid_classification_net.add(
tf.keras.layers.Dense(self._track_params.reid_embed_size,
input_shape=(
self._track_params.reid_embed_size,)))
self.track_reid_classification_net.add(
tf.keras.layers.BatchNormalization())
self.track_reid_classification_net.add(tf.keras.layers.ReLU())
self.track_reid_classification_net.add(
tf.keras.layers.Dense(self._track_params.num_track_ids,
input_shape=(
self._track_params.reid_embed_size,)))
return prediction_heads
def _initialize_target_assigners(self, stride, min_box_overlap_iou):
......@@ -1704,6 +1795,10 @@ class CenterNetMetaArch(model.DetectionModel):
dp_stride = 1 if self._densepose_params.upsample_to_input_res else stride
target_assigners[DENSEPOSE_TASK] = (
cn_assigner.CenterNetDensePoseTargetAssigner(dp_stride))
if self._track_params is not None:
target_assigners[TRACK_TASK] = (
cn_assigner.CenterNetTrackTargetAssigner(
stride, self._track_params.num_track_ids))
return target_assigners
......@@ -2222,6 +2317,76 @@ class CenterNetMetaArch(model.DetectionModel):
num_predictions * num_valid_points)
return part_prediction_loss, surface_coord_loss
def _compute_track_losses(self, input_height, input_width, prediction_dict):
  """Gathers every loss belonging to the tracking task.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    prediction_dict: The dictionary returned from the predict() method. Only
      its TRACK_REID entry is read here.

  Returns:
    A dictionary with tracking losses; it holds a single entry keyed by
    TRACK_REID with the ReID embedding loss.
  """
  return {
      TRACK_REID: self._compute_track_embedding_loss(
          input_height=input_height,
          input_width=input_width,
          object_reid_predictions=prediction_dict[TRACK_REID])
  }
def _compute_track_embedding_loss(self, input_height, input_width,
                                  object_reid_predictions):
  """Computes the object ReID loss.

  The embedding is trained as a classification task where the target is the
  ID of each track among all tracks in the whole dataset.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    object_reid_predictions: A list of float tensors of shape [batch_size,
      out_height, out_width, reid_embed_size] representing the object
      embedding feature maps.

  Returns:
    A float scalar tensor representing the object ReID loss per instance,
    i.e. the summed loss divided by (number of feature outputs * number of
    instances derived from the groundtruth weights).
  """
  track_ids_per_image = self.groundtruth_lists(fields.BoxListFields.track_ids)
  boxes_per_image = self.groundtruth_lists(fields.BoxListFields.boxes)
  weights_per_image = self.groundtruth_lists(fields.BoxListFields.weights)
  num_instances = _to_float32(
      get_num_instances_from_weights(weights_per_image))

  # Turn the groundtruth into per-box center indices, weights and one-hot
  # track ID targets.
  track_assigner = self._target_assigner_dict[TRACK_TASK]
  center_indices, instance_weights, one_hot_targets = (
      track_assigner.assign_track_targets(
          height=input_height,
          width=input_width,
          gt_track_ids_list=track_ids_per_image,
          gt_boxes_list=boxes_per_image,
          gt_weights_list=weights_per_image))
  # Add a trailing dimension to the weights before handing them to the loss.
  instance_weights = tf.expand_dims(instance_weights, -1)

  reid_loss_fn = self._track_params.classification_loss
  accumulated_loss = 0.0
  # Accumulate the classification loss over every feature output head.
  for reid_feature_map in object_reid_predictions:
    center_embeddings = cn_assigner.get_batch_predictions_from_indices(
        reid_feature_map, center_indices)
    # Project the embeddings at the box centers into track ID space.
    track_id_logits = self.track_reid_classification_net(center_embeddings)
    accumulated_loss += reid_loss_fn(
        track_id_logits, one_hot_targets, weights=instance_weights)

  num_heads = float(len(object_reid_predictions))
  return tf.reduce_sum(accumulated_loss) / (num_heads * num_instances)
def preprocess(self, inputs):
outputs = shape_utils.resize_images_and_return_shapes(
inputs, self._image_resizer_fn)
......@@ -2316,7 +2481,8 @@ class CenterNetMetaArch(model.DetectionModel):
'Loss/$TASK_NAME/keypoint/regression', (optional)
'Loss/segmentation/heatmap', (optional)
'Loss/densepose/heatmap', (optional)
'Loss/densepose/regression]' (optional)
'Loss/densepose/regression', (optional)
'Loss/track/reid'] (optional)
scalar tensors corresponding to the losses for different tasks. Note the
$TASK_NAME is provided by the KeypointEstimation namedtuple used to
differentiate between different keypoint tasks.
......@@ -2384,6 +2550,16 @@ class CenterNetMetaArch(model.DetectionModel):
densepose_losses[key] * self._densepose_params.task_loss_weight)
losses.update(densepose_losses)
if self._track_params is not None:
track_losses = self._compute_track_losses(
input_height=input_height,
input_width=input_width,
prediction_dict=prediction_dict)
for key in track_losses:
track_losses[key] = (
track_losses[key] * self._track_params.task_loss_weight)
losses.update(track_losses)
# Prepend the LOSS_KEY_PREFIX to the keys in the dictionary such that the
# losses will be grouped together in Tensorboard.
return dict([('%s/%s' % (LOSS_KEY_PREFIX, key), val)
......@@ -2426,6 +2602,8 @@ class CenterNetMetaArch(model.DetectionModel):
detection_surface_coords: (Optional) A float32 tensor of shape [batch,
max_detection, mask_height, mask_width, 2] with DensePose surface
coordinates, in (v, u) format.
detection_embeddings: (Optional) A float tensor of shape [batch,
max_detections, reid_embed_size] containing object embeddings.
"""
object_center_prob = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
# Get x, y and channel indices corresponding to the top indices in the class
......@@ -2487,8 +2665,39 @@ class CenterNetMetaArch(model.DetectionModel):
fields.DetectionResultFields.detection_surface_coords] = (
surface_coords)
if self._track_params:
embeddings = self._postprocess_embeddings(prediction_dict,
y_indices, x_indices)
postprocess_dict.update({
fields.DetectionResultFields.detection_embeddings: embeddings
})
return postprocess_dict
def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices):
  """Performs postprocessing on embedding predictions.

  Reads the last TRACK_REID feature map from `prediction_dict`, gathers the
  embedding vectors at the given object centers and L2-normalizes each one
  along its last axis.

  Args:
    prediction_dict: a dictionary holding predicted tensors, returned from the
      predict() method. This dictionary should contain embedding prediction
      feature maps for tracking task.
    y_indices: A [batch_size, max_detections] int tensor with y indices for
      all object centers.
    x_indices: A [batch_size, max_detections] int tensor with x indices for
      all object centers.

  Returns:
    embeddings: A [batch_size, max_detection, reid_embed_size] float32
      tensor with L2 normalized embeddings extracted from detection box
      centers. An all-zero embedding stays all-zero instead of becoming NaN.
  """
  embedding_predictions = prediction_dict[TRACK_REID][-1]
  embeddings = predicted_embeddings_at_object_centers(
      embedding_predictions, y_indices, x_indices)
  # tf.math.l2_normalize lower-bounds the squared norm with a small epsilon,
  # so a zero vector normalizes to zeros. Dividing by the raw norm (as
  # tf.linalg.normalize does) would yield NaNs for such a vector, which would
  # then propagate into any downstream embedding comparison.
  return tf.math.l2_normalize(embeddings, axis=-1)
def _postprocess_keypoints(self, prediction_dict, classes, y_indices,
x_indices, boxes, num_detections):
"""Performs postprocessing on keypoint predictions.
......
......@@ -1046,6 +1046,41 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_allclose(kpt_scores_np[:, i, :],
kpt_scores_padded[:, inst_ind, :])
def test_predicted_embeddings_at_object_centers(self):
  """Gathered embeddings must match direct numpy indexing of the map."""
  batch_size = 2
  embedding_size = 5
  num_instances = 6
  feature_map_np = np.random.randn(
      batch_size, 10, 10, embedding_size).astype(np.float32)
  y_indices = np.random.choice(10, (batch_size, num_instances))
  x_indices = np.random.choice(10, (batch_size, num_instances))

  def graph_fn():
    feature_map = tf.constant(feature_map_np, dtype=tf.float32)
    y_centers = tf.constant(y_indices, dtype=tf.int32)
    x_centers = tf.constant(x_indices, dtype=tf.int32)
    return cnma.predicted_embeddings_at_object_centers(
        feature_map, y_centers, x_centers)

  gathered_predicted_embeddings = self.execute(graph_fn, [])

  # Reference result: index the numpy feature map one batch element at a
  # time with the same center coordinates.
  expected_per_image = [
      feature_map_np[image_index, y_indices[image_index],
                     x_indices[image_index], :]
      for image_index in range(batch_size)
  ]
  expected_gathered_embeddings = np.reshape(
      np.stack(expected_per_image, axis=0),
      [batch_size, num_instances, embedding_size])

  np.testing.assert_allclose(expected_gathered_embeddings,
                             gathered_predicted_embeddings)
# Common parameters for setting up testing examples across tests.
_NUM_CLASSES = 10
......@@ -1053,6 +1088,9 @@ _KEYPOINT_INDICES = [0, 1, 2, 3]
_NUM_KEYPOINTS = len(_KEYPOINT_INDICES)
_DENSEPOSE_NUM_PARTS = 24
_TASK_NAME = 'human_pose'
_NUM_TRACK_IDS = 3
_REID_EMBED_SIZE = 2
_NUM_FC_LAYERS = 1
def get_fake_center_params():
......@@ -1108,6 +1146,16 @@ def get_fake_densepose_params():
upsample_method='nearest')
def get_fake_track_params():
  """Builds a TrackParams namedtuple filled with small test values."""
  softmax_loss = losses.WeightedSoftmaxClassificationLoss()
  return cnma.TrackParams(
      num_track_ids=_NUM_TRACK_IDS,
      reid_embed_size=_REID_EMBED_SIZE,
      num_fc_layers=_NUM_FC_LAYERS,
      classification_loss=softmax_loss,
      task_loss_weight=1.0)
def build_center_net_meta_arch(build_resnet=False):
"""Builds the CenterNet meta architecture."""
if build_resnet:
......@@ -1136,7 +1184,8 @@ def build_center_net_meta_arch(build_resnet=False):
object_detection_params=get_fake_od_params(),
keypoint_params_dict={_TASK_NAME: get_fake_kp_params()},
mask_params=get_fake_mask_params(),
densepose_params=get_fake_densepose_params())
densepose_params=get_fake_densepose_params(),
track_params=get_fake_track_params())
def _logit(p):
......@@ -1230,6 +1279,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
fake_feature_map)
self.assertEqual((4, 128, 128, 2 * _DENSEPOSE_NUM_PARTS), output.shape)
# "track embedding" head:
output = model._prediction_head_dict[cnma.TRACK_REID][-1](
fake_feature_map)
self.assertEqual((4, 128, 128, _REID_EMBED_SIZE), output.shape)
def test_initialize_target_assigners(self):
model = build_center_net_meta_arch()
assigner_dict = model._initialize_target_assigners(
......@@ -1257,6 +1311,10 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertIsInstance(assigner_dict[cnma.DENSEPOSE_TASK],
cn_assigner.CenterNetDensePoseTargetAssigner)
# Track estimation target assigner:
self.assertIsInstance(assigner_dict[cnma.TRACK_TASK],
cn_assigner.CenterNetTrackTargetAssigner)
def test_predict(self):
"""Test the predict function."""
......@@ -1281,6 +1339,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
(2, 32, 32, _DENSEPOSE_NUM_PARTS))
self.assertEqual(prediction_dict[cnma.DENSEPOSE_REGRESSION][0].shape,
(2, 32, 32, 2 * _DENSEPOSE_NUM_PARTS))
self.assertEqual(prediction_dict[cnma.TRACK_REID][0].shape,
(2, 32, 32, _REID_EMBED_SIZE))
def test_loss(self):
"""Test the loss function."""
......@@ -1299,7 +1359,16 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
groundtruth_dp_part_ids_list=groundtruth_dict[
fields.BoxListFields.densepose_part_ids],
groundtruth_dp_surface_coords_list=groundtruth_dict[
fields.BoxListFields.densepose_surface_coords])
fields.BoxListFields.densepose_surface_coords],
groundtruth_track_ids_list=groundtruth_dict[
fields.BoxListFields.track_ids])
kernel_initializer = tf.constant_initializer(
[[1, 1, 0], [-1000000, -1000000, 1000000]])
model.track_reid_classification_net = tf.keras.layers.Dense(
_NUM_TRACK_IDS,
kernel_initializer=kernel_initializer,
input_shape=(_REID_EMBED_SIZE,))
prediction_dict = get_fake_prediction_dict(
input_height=16, input_width=32, stride=4)
......@@ -1341,6 +1410,9 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertGreater(
0.01, loss_dict['%s/%s' % (cnma.LOSS_KEY_PREFIX,
cnma.DENSEPOSE_REGRESSION)])
self.assertGreater(
0.01, loss_dict['%s/%s' % (cnma.LOSS_KEY_PREFIX,
cnma.TRACK_REID)])
@parameterized.parameters(
{'target_class_id': 1},
......@@ -1386,6 +1458,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
dp_surf_coords = np.random.randn(1, 32, 32, 2 * _DENSEPOSE_NUM_PARTS)
embedding_size = 100
track_reid_embedding = np.zeros((1, 32, 32, embedding_size),
dtype=np.float32)
track_reid_embedding[0, 16, 16, :] = np.ones(embedding_size)
class_center = tf.constant(class_center)
height_width = tf.constant(height_width)
offset = tf.constant(offset)
......@@ -1395,6 +1472,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
segmentation_heatmap = tf.constant(segmentation_heatmap, dtype=tf.float32)
dp_part_heatmap = tf.constant(dp_part_heatmap, dtype=tf.float32)
dp_surf_coords = tf.constant(dp_surf_coords, dtype=tf.float32)
track_reid_embedding = tf.constant(track_reid_embedding, dtype=tf.float32)
prediction_dict = {
cnma.OBJECT_CENTER: [class_center],
......@@ -1408,7 +1486,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
[keypoint_regression],
cnma.SEGMENTATION_HEATMAP: [segmentation_heatmap],
cnma.DENSEPOSE_HEATMAP: [dp_part_heatmap],
cnma.DENSEPOSE_REGRESSION: [dp_surf_coords]
cnma.DENSEPOSE_REGRESSION: [dp_surf_coords],
cnma.TRACK_REID: [track_reid_embedding]
}
def graph_fn():
......@@ -1422,6 +1501,14 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
np.array([55, 46, 75, 86]) / 128.0)
self.assertAllClose(detections['detection_scores'][0],
[.75, .5, .5, .5, .5])
# The output embedding extracted at the object center will be a 3-D array of
# shape [batch, num_boxes, embedding_size]. The valid predicted embedding
# will be the first embedding in the first batch. It is a 1-D array of
# shape [embedding_size] with values all ones. All the values of the
# embedding will then be divided by the square root of 'embedding_size'
# after the L2 normalization.
self.assertAllClose(detections['detection_embeddings'][0, 0],
np.ones(embedding_size) / embedding_size**0.5)
self.assertEqual(detections['detection_classes'][0, 0], target_class_id)
self.assertEqual(detections['num_detections'], [5])
self.assertAllEqual([1, max_detection, num_keypoints, 2],
......@@ -1430,6 +1517,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
detections['detection_keypoint_scores'].shape)
self.assertAllEqual([1, max_detection, 4, 4],
detections['detection_masks'].shape)
self.assertAllEqual([1, max_detection, embedding_size],
detections['detection_embeddings'].shape)
# Masks should be empty for everything but the first detection.
self.assertAllEqual(
......@@ -1539,6 +1628,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
# (5 * 2, 5 * 2 + 1), or (10, 11).
densepose_regression[0, 2, 4, 10:12] = 0.4, 0.7
track_reid_embedding = np.zeros((2, output_height, output_width,
_REID_EMBED_SIZE), dtype=np.float32)
track_reid_embedding[0, 2, 4, :] = np.arange(_REID_EMBED_SIZE)
prediction_dict = {
'preprocessed_inputs':
tf.zeros((2, input_height, input_width, 3)),
......@@ -1577,6 +1670,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
cnma.DENSEPOSE_REGRESSION: [
tf.constant(densepose_regression),
tf.constant(densepose_regression),
],
cnma.TRACK_REID: [
tf.constant(track_reid_embedding),
tf.constant(track_reid_embedding),
]
}
return prediction_dict
......@@ -1635,6 +1732,10 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
tf.constant(densepose_surface_coords_np),
tf.zeros_like(densepose_surface_coords_np)
]
track_ids = [
tf.constant([2], dtype=tf.int32),
tf.constant([1], dtype=tf.int32),
]
groundtruth_dict = {
fields.BoxListFields.boxes: boxes,
fields.BoxListFields.weights: weights,
......@@ -1645,6 +1746,7 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
fields.BoxListFields.densepose_part_ids: densepose_part_ids,
fields.BoxListFields.densepose_surface_coords:
densepose_surface_coords,
fields.BoxListFields.track_ids: track_ids,
fields.InputDataFields.groundtruth_labeled_classes: labeled_classes,
}
return groundtruth_dict
......@@ -1778,6 +1880,27 @@ class CenterNetMetaComputeLossTest(test_case.TestCase):
# The prediction and groundtruth are curated to produce very low loss.
self.assertGreater(0.01, loss)
def test_compute_track_embedding_loss(self):
  """Checks the track embedding loss is near zero for the curated inputs."""
  original_classification_net = self.model.track_reid_classification_net
  # Extreme kernel values push the classification score towards (0, 0, 1)
  # after the softmax layer.
  extreme_kernel = tf.constant_initializer(
      [[1, 1, 0], [-1000000, -1000000, 1000000]])
  replacement_net = tf.keras.layers.Dense(
      _NUM_TRACK_IDS,
      kernel_initializer=extreme_kernel,
      input_shape=(_REID_EMBED_SIZE,))
  self.model.track_reid_classification_net = replacement_net

  reid_predictions = self.prediction_dict[cnma.TRACK_REID]
  loss = self.model._compute_track_embedding_loss(
      input_height=self.input_height,
      input_width=self.input_width,
      object_reid_predictions=reid_predictions)

  # Put the model's own classification net back before asserting.
  self.model.track_reid_classification_net = original_classification_net

  # The prediction and groundtruth are curated to produce very low loss.
  self.assertGreater(0.01, loss)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class CenterNetMetaArchRestoreTest(test_case.TestCase):
......
......@@ -102,6 +102,8 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
'groundtruth_dp_surface_coords_list': [batch_size, num_boxes,
max_sampled_points, 4] containing the DensePose surface coordinates for
each sampled point (if provided in groundtruth).
'groundtruth_track_ids_list': [batch_size, num_boxes] int32 tensor
with track ID for each instance (if provided in groundtruth).
'groundtruth_group_of': [batch_size, num_boxes] bool tensor indicating
group_of annotations (if provided in groundtruth).
'groundtruth_labeled_classes': [batch_size, num_classes] int64
......@@ -187,6 +189,11 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
groundtruth[input_data_fields.groundtruth_dp_surface_coords] = tf.stack(
detection_model.groundtruth_lists(
fields.BoxListFields.densepose_surface_coords))
if detection_model.groundtruth_has_field(fields.BoxListFields.track_ids):
groundtruth[input_data_fields.groundtruth_track_ids] = tf.stack(
detection_model.groundtruth_lists(fields.BoxListFields.track_ids))
groundtruth[input_data_fields.num_groundtruth_boxes] = (
tf.tile([max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
return groundtruth
......@@ -245,6 +252,7 @@ def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
fields.InputDataFields.groundtruth_dp_num_points,
fields.InputDataFields.groundtruth_dp_part_ids,
fields.InputDataFields.groundtruth_dp_surface_coords,
fields.InputDataFields.groundtruth_track_ids,
fields.InputDataFields.groundtruth_group_of,
fields.InputDataFields.groundtruth_difficult,
fields.InputDataFields.groundtruth_is_crowd,
......@@ -307,6 +315,10 @@ def provide_groundtruth(model, labels):
if fields.InputDataFields.groundtruth_dp_surface_coords in labels:
gt_dp_surface_coords_list = labels[
fields.InputDataFields.groundtruth_dp_surface_coords]
gt_track_ids_list = None
if fields.InputDataFields.groundtruth_track_ids in labels:
gt_track_ids_list = labels[
fields.InputDataFields.groundtruth_track_ids]
gt_weights_list = None
if fields.InputDataFields.groundtruth_weights in labels:
gt_weights_list = labels[fields.InputDataFields.groundtruth_weights]
......@@ -341,7 +353,8 @@ def provide_groundtruth(model, labels):
groundtruth_weights_list=gt_weights_list,
groundtruth_is_crowd_list=gt_is_crowd_list,
groundtruth_group_of_list=gt_group_of_list,
groundtruth_area_list=gt_area_list)
groundtruth_area_list=gt_area_list,
groundtruth_track_ids_list=gt_track_ids_list)
def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
......
......@@ -104,6 +104,8 @@ def _compute_losses_and_predictions_dicts(
containing group_of annotations.
labels[fields.InputDataFields.groundtruth_labeled_classes] is a float32
k-hot tensor of classes.
labels[fields.InputDataFields.groundtruth_track_ids] is an int32
tensor of track IDs.
add_regularization_loss: Whether or not to include the model's
regularization loss in the losses dictionary.
......@@ -216,6 +218,8 @@ def eager_train_step(detection_model,
(v, u) are part-relative normalized surface coordinates.
labels[fields.InputDataFields.groundtruth_labeled_classes] is a float32
k-hot tensor of classes.
labels[fields.InputDataFields.groundtruth_track_ids] is an int32
tensor of track IDs.
unpad_groundtruth_tensors: A parameter passed to unstack_batch.
optimizer: The training optimizer that will update the variables.
learning_rate: The learning rate tensor for the current training step.
......
......@@ -218,6 +218,32 @@ message CenterNet {
optional float heatmap_bias_init = 8 [default = -2.19];
}
optional DensePoseEstimation densepose_estimation_task = 9;
// Parameters related to the tracking embedding estimation task.
// See "A Simple Baseline for Multi-Object Tracking" [2].
// [2]: https://arxiv.org/abs/2004.01888
message TrackEstimation {
// Weight of the task loss. The total loss of the model is the sum of the
// task losses, each scaled by its weight.
optional float task_loss_weight = 1 [default = 1.0];
// The maximum track ID of the dataset. This value is also the size of the
// track ID classification space (see 'num_fc_layers' below).
optional int32 num_track_ids = 2;
// The embedding size for the re-identification (ReID) task in tracking.
optional int32 reid_embed_size = 3 [default = 128];
// The number of (fully-connected, batch-norm, relu) layers in the track ID
// classification head. The output dimension of every intermediate FC layer
// is 'reid_embed_size'. The last FC layer projects directly to the track ID
// classification space of size 'num_track_ids', without batch-norm and relu
// layers.
optional int32 num_fc_layers = 4 [default = 1];
// Classification loss configuration for the ReID loss.
optional ClassificationLoss classification_loss = 5;
}
optional TrackEstimation track_estimation_task = 10;
}
message CenterNetFeatureExtractor {
......
......@@ -30,7 +30,7 @@ enum InputType {
TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input
}
// Next id: 33
// Next id: 34
message InputReader {
// Name of input reader. Typically used to describe the dataset that is read
// by this input reader.
......@@ -122,6 +122,9 @@ message InputReader {
// to true.
optional bool load_dense_pose = 31 [default = false];
// Whether to load track information.
optional bool load_track_id = 33 [default = false];
// Whether to use the display name when decoding examples. This is only used
// when mapping class text strings to integers.
optional bool use_display_name = 17 [default = false];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册