Commit bb43ed96 authored by Zhichao Lu, committed by pkulzc

Write groundtruth weights from input pipeline into model.

PiperOrigin-RevId: 190636417
Parent 45069b91
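
For context, the per-box groundtruth weights plumbed through here let downstream losses down-weight or ignore individual boxes (for example, padded or otherwise unlabeled boxes). A minimal sketch of the idea, using made-up tensors that are not part of this commit:

import tensorflow as tf

# Hypothetical example: two groundtruth boxes, where the second row is padding.
groundtruth_boxes = tf.constant([[0.1, 0.1, 0.5, 0.5],
                                 [0.0, 0.0, 0.0, 0.0]])
groundtruth_weights = tf.constant([1.0, 0.0])   # weight 0.0 => box is ignored
per_box_loss = tf.constant([2.3, 4.1])          # made-up unweighted per-box losses
weighted_loss = tf.reduce_sum(per_box_loss * groundtruth_weights)  # only the first box contributes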
@@ -160,6 +160,52 @@ def augment_input_data(tensor_dict, data_augmentation_options):
   return tensor_dict
 
 
+def _get_labels_dict(input_dict):
+  """Extracts labels dict from input dict."""
+  required_label_keys = [
+      fields.InputDataFields.num_groundtruth_boxes,
+      fields.InputDataFields.groundtruth_boxes,
+      fields.InputDataFields.groundtruth_classes,
+      fields.InputDataFields.groundtruth_weights
+  ]
+  labels_dict = {}
+  for key in required_label_keys:
+    labels_dict[key] = input_dict[key]
+
+  optional_label_keys = [
+      fields.InputDataFields.groundtruth_keypoints,
+      fields.InputDataFields.groundtruth_instance_masks,
+      fields.InputDataFields.groundtruth_area,
+      fields.InputDataFields.groundtruth_is_crowd,
+      fields.InputDataFields.groundtruth_difficult
+  ]
+  for key in optional_label_keys:
+    if key in input_dict:
+      labels_dict[key] = input_dict[key]
+  if fields.InputDataFields.groundtruth_difficult in labels_dict:
+    labels_dict[fields.InputDataFields.groundtruth_difficult] = tf.cast(
+        labels_dict[fields.InputDataFields.groundtruth_difficult], tf.int32)
+  return labels_dict
+
+
+def _get_features_dict(input_dict):
+  """Extracts features dict from input dict."""
+  hash_from_source_id = tf.string_to_hash_bucket_fast(
+      input_dict[fields.InputDataFields.source_id], HASH_BINS)
+  features = {
+      fields.InputDataFields.image:
+          input_dict[fields.InputDataFields.image],
+      HASH_KEY: tf.cast(hash_from_source_id, tf.int32),
+      fields.InputDataFields.true_image_shape:
+          input_dict[fields.InputDataFields.true_image_shape]
+  }
+  if fields.InputDataFields.original_image in input_dict:
+    features[fields.InputDataFields.original_image] = input_dict[
+        fields.InputDataFields.original_image]
+  return features
+
+
 def create_train_input_fn(train_config, train_input_config,
                           model_config):
   """Creates a train `input` function for `Estimator`.
@@ -249,38 +295,8 @@ def create_train_input_fn(train_config, train_input_config,
         num_classes=config_util.get_number_of_classes(model_config),
         spatial_image_shape=config_util.get_spatial_image_size(
             image_resizer_config))
-    tensor_dict = dataset_util.make_initializable_iterator(dataset).get_next()
-
-    hash_from_source_id = tf.string_to_hash_bucket_fast(
-        tensor_dict[fields.InputDataFields.source_id], HASH_BINS)
-    features = {
-        fields.InputDataFields.image: tensor_dict[fields.InputDataFields.image],
-        HASH_KEY: tf.cast(hash_from_source_id, tf.int32),
-        fields.InputDataFields.true_image_shape: tensor_dict[
-            fields.InputDataFields.true_image_shape]
-    }
-    if fields.InputDataFields.original_image in tensor_dict:
-      features[fields.InputDataFields.original_image] = tensor_dict[
-          fields.InputDataFields.original_image]
-
-    labels = {
-        fields.InputDataFields.num_groundtruth_boxes: tensor_dict[
-            fields.InputDataFields.num_groundtruth_boxes],
-        fields.InputDataFields.groundtruth_boxes: tensor_dict[
-            fields.InputDataFields.groundtruth_boxes],
-        fields.InputDataFields.groundtruth_classes: tensor_dict[
-            fields.InputDataFields.groundtruth_classes],
-        fields.InputDataFields.groundtruth_weights: tensor_dict[
-            fields.InputDataFields.groundtruth_weights]
-    }
-    if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
-      labels[fields.InputDataFields.groundtruth_keypoints] = tensor_dict[
-          fields.InputDataFields.groundtruth_keypoints]
-    if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
-      labels[fields.InputDataFields.groundtruth_instance_masks] = tensor_dict[
-          fields.InputDataFields.groundtruth_instance_masks]
-
-    return features, labels
+    input_dict = dataset_util.make_initializable_iterator(dataset).get_next()
+    return (_get_features_dict(input_dict), _get_labels_dict(input_dict))
 
   return _train_input_fn
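
For illustration, a hypothetical use of the two new helpers on a hand-built single-example input_dict. All tensor values below are invented; _get_features_dict, _get_labels_dict, HASH_BINS and HASH_KEY come from this module, and the import path assumes the usual Object Detection API layout:

import tensorflow as tf
from object_detection.core import standard_fields as fields

# Hypothetical decoded example; in practice this comes from the dataset iterator.
input_dict = {
    fields.InputDataFields.image: tf.zeros([300, 300, 3], tf.float32),
    fields.InputDataFields.source_id: tf.constant('image_001'),
    fields.InputDataFields.true_image_shape: tf.constant([300, 300, 3]),
    fields.InputDataFields.num_groundtruth_boxes: tf.constant(1),
    fields.InputDataFields.groundtruth_boxes: tf.constant([[0.1, 0.1, 0.9, 0.9]]),
    fields.InputDataFields.groundtruth_classes: tf.constant([[0., 1.]]),
    fields.InputDataFields.groundtruth_weights: tf.constant([1.0]),
}
features = _get_features_dict(input_dict)  # image, hashed source_id, true_image_shape
labels = _get_labels_dict(input_dict)      # boxes, classes, num boxes, and now weights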
@@ -365,36 +381,7 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
             image_resizer_config))
     input_dict = dataset_util.make_initializable_iterator(dataset).get_next()
-
-    hash_from_source_id = tf.string_to_hash_bucket_fast(
-        input_dict[fields.InputDataFields.source_id], HASH_BINS)
-    features = {
-        fields.InputDataFields.image:
-            input_dict[fields.InputDataFields.image],
-        fields.InputDataFields.original_image:
-            input_dict[fields.InputDataFields.original_image],
-        HASH_KEY: tf.cast(hash_from_source_id, tf.int32),
-        fields.InputDataFields.true_image_shape:
-            input_dict[fields.InputDataFields.true_image_shape]
-    }
-
-    labels = {
-        fields.InputDataFields.groundtruth_boxes:
-            input_dict[fields.InputDataFields.groundtruth_boxes],
-        fields.InputDataFields.groundtruth_classes:
-            input_dict[fields.InputDataFields.groundtruth_classes],
-        fields.InputDataFields.groundtruth_area:
-            input_dict[fields.InputDataFields.groundtruth_area],
-        fields.InputDataFields.groundtruth_is_crowd:
-            input_dict[fields.InputDataFields.groundtruth_is_crowd],
-        fields.InputDataFields.groundtruth_difficult:
-            tf.cast(input_dict[fields.InputDataFields.groundtruth_difficult],
-                    tf.int32)
-    }
-    if fields.InputDataFields.groundtruth_instance_masks in input_dict:
-      labels[fields.InputDataFields.groundtruth_instance_masks] = input_dict[
-          fields.InputDataFields.groundtruth_instance_masks]
-
-    return features, labels
+    return (_get_features_dict(input_dict), _get_labels_dict(input_dict))
 
   return _eval_input_fn
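
Both factory functions return an input_fn for the TF 1.x Estimator API. A hypothetical wiring sketch; the *_config protos, the model_dir, and model_fn (built by create_model_fn, see the next hunk) are assumptions and not part of this diff:

import tensorflow as tf

train_input_fn = create_train_input_fn(train_config, train_input_config,
                                       model_config)
eval_input_fn = create_eval_input_fn(eval_config, eval_input_config,
                                     model_config)
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='/tmp/od_model')
estimator.train(input_fn=train_input_fn, max_steps=train_config.num_steps)
metrics = estimator.evaluate(input_fn=eval_input_fn)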
@@ -241,7 +241,9 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
           groundtruth_boxes_list=gt_boxes_list,
           groundtruth_classes_list=gt_classes_list,
           groundtruth_masks_list=gt_masks_list,
-          groundtruth_keypoints_list=gt_keypoints_list)
+          groundtruth_keypoints_list=gt_keypoints_list,
+          groundtruth_weights_list=labels[
+              fields.InputDataFields.groundtruth_weights])
 
     preprocessed_images = features[fields.InputDataFields.image]
     prediction_dict = detection_model.predict(
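
With the weights now present in labels, provide_groundtruth receives one weight vector per image alongside the boxes, so boxes given weight 0 can be ignored during target assignment and loss computation. As a hypothetical illustration (not part of this commit), a pipeline that pads each example to a fixed number of boxes could derive such weights from num_groundtruth_boxes:

import tensorflow as tf

# Hypothetical batched labels, padded to a fixed max number of boxes per image.
num_boxes = labels[fields.InputDataFields.num_groundtruth_boxes]    # shape [batch_size]
max_boxes = tf.shape(labels[fields.InputDataFields.groundtruth_boxes])[1]
# 1.0 for real boxes, 0.0 for padded rows.
padding_weights = tf.to_float(tf.sequence_mask(num_boxes, max_boxes))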