Commit ad32e81e authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 524770642
Parent 5d10b3bc
@@ -85,6 +85,31 @@ class DetectionModule(export_base.ExportModule):
     return image, anchor_boxes, image_info
 
+  def _normalize_coordinates(self, detections_dict, dict_keys, image_info):
+    """Normalizes detection coordinates between 0 and 1.
+
+    Args:
+      detections_dict: Dictionary containing the output of the model prediction.
+      dict_keys: Key names corresponding to the tensors of the output dictionary
+        that we want to update.
+      image_info: Tensor containing the details of the image resizing.
+
+    Returns:
+      detections_dict: Updated detection dictionary.
+    """
+    for key in dict_keys:
+      if key not in detections_dict:
+        continue
+      detection_boxes = detections_dict[key] / tf.tile(
+          image_info[:, 2:3, :], [1, 1, 2]
+      )
+      detections_dict[key] = box_ops.normalize_boxes(
+          detection_boxes, image_info[:, 0:1, :]
+      )
+      detections_dict[key] = tf.clip_by_value(detections_dict[key], 0.0, 1.0)
+
+    return detections_dict
+
   def preprocess(
       self, images: tf.Tensor
   ) -> Tuple[tf.Tensor, Mapping[str, tf.Tensor], tf.Tensor]:
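For context when reading the hunk above: the new helper first divides boxes by the resize scale stored in `image_info[:, 2:3, :]` (mapping them back to original-image pixels), then divides by the original height/width in `image_info[:, 0:1, :]` and clips to [0, 1]. Below is a minimal sketch of that arithmetic with made-up numbers; the `image_info` row layout (original size, resized size, scale, offset) is an assumption inferred from the slices used in the diff, and `box_ops.normalize_boxes` is replaced here by a plain division by height/width.

```python
import tensorflow as tf

# Assumed image_info layout: [[original hw], [resized hw], [scale], [offset]].
# One image, originally 800x1000, resized with scale (0.8, 0.64) -> 640x640.
image_info = tf.constant([[[800.0, 1000.0],   # original (height, width)
                           [640.0, 640.0],    # resized (height, width)
                           [0.8, 0.64],       # (y_scale, x_scale)
                           [0.0, 0.0]]])      # (y_offset, x_offset)

# One detection box in resized-image pixels: (ymin, xmin, ymax, xmax).
boxes = tf.constant([[[64.0, 64.0, 320.0, 320.0]]])

# Step 1: undo the resize scale -> original-image pixel coordinates.
boxes_original = boxes / tf.tile(image_info[:, 2:3, :], [1, 1, 2])

# Step 2: divide by the original (height, width), which is what
# box_ops.normalize_boxes effectively does.
normalized = boxes_original / tf.tile(image_info[:, 0:1, :], [1, 1, 2])

# Step 3: clamp to the valid range, mirroring the tf.clip_by_value call.
normalized = tf.clip_by_value(normalized, 0.0, 1.0)
print(normalized.numpy())  # -> [[[0.1 0.1 0.5 0.5]]]
```

The final clip matters mostly for the no-NMS path added later in this commit, where raw decoded boxes can extend slightly past the image border.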
@@ -180,13 +205,8 @@ class DetectionModule(export_base.ExportModule):
         export_config = self.params.task.export_config
         # Normalize detection box coordinates to [0, 1].
         if export_config.output_normalized_coordinates:
-          for key in ['detection_boxes', 'detection_outer_boxes']:
-            if key not in detections:
-              continue
-            detection_boxes = (
-                detections[key] / tf.tile(image_info[:, 2:3, :], [1, 1, 2]))
-            detections[key] = box_ops.normalize_boxes(
-                detection_boxes, image_info[:, 0:1, :])
+          keys = ['detection_boxes', 'detection_outer_boxes']
+          detections = self._normalize_coordinates(detections, keys, image_info)
 
       # Cast num_detections and detection_classes to float. This allows the
       # model inference to work on chain (go/chain) as chain requires floating
@@ -208,6 +228,13 @@ class DetectionModule(export_base.ExportModule):
         final_outputs['detection_outer_boxes'] = (
             detections['detection_outer_boxes'])
     else:
+      # For RetinaNet model, apply export_config.
+      if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
+        export_config = self.params.task.export_config
+        # Normalize detection box coordinates to [0, 1].
+        if export_config.output_normalized_coordinates:
+          keys = ['decoded_boxes']
+          detections = self._normalize_coordinates(detections, keys, image_info)
       final_outputs = {
           'decoded_boxes': detections['decoded_boxes'],
           'decoded_box_scores': detections['decoded_box_scores']
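With this hunk, the `output_normalized_coordinates` flag now also applies to the pre-NMS `decoded_boxes` output of RetinaNet. A hedged sketch of how the flag could be enabled when building an export config follows; the experiment name and config fields are taken from the test changes below, while the `official.core.exp_factory` import path is an assumption about the Model Garden layout.

```python
from official.core import exp_factory  # assumed import path

# Build a RetinaNet experiment config that exports pre-NMS boxes in [0, 1].
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
params.task.model.detection_generator.apply_nms = False
params.task.export_config.output_normalized_coordinates = True
```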
@@ -34,12 +34,17 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
       experiment_name,
       input_type,
       outer_boxes_scale=1.0,
+      apply_nms=True,
+      normalized_coordinates=False,
       nms_version='batched',
       output_intermediate_features=False,
   ):
     params = exp_factory.get_exp_config(experiment_name)
     params.task.model.outer_boxes_scale = outer_boxes_scale
     params.task.model.backbone.resnet.model_id = 18
+    params.task.model.detection_generator.apply_nms = apply_nms
+    if normalized_coordinates:
+      params.task.export_config.output_normalized_coordinates = True
     params.task.model.detection_generator.nms_version = nms_version
     if output_intermediate_features:
       params.task.export_config.output_intermediate_features = True
@@ -184,6 +189,49 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
         outputs.keys(),
     )
 
+  @parameterized.parameters(
+      ('image_tensor', 'retinanet_resnetfpn_coco', [640, 640]),
+      ('image_bytes', 'retinanet_resnetfpn_coco', [640, 640]),
+      ('tf_example', 'retinanet_resnetfpn_coco', [384, 640]),
+      ('tflite', 'retinanet_resnetfpn_coco', [640, 640]),
+      ('image_tensor', 'retinanet_resnetfpn_coco', [384, 384]),
+      ('image_bytes', 'retinanet_spinenet_coco', [640, 640]),
+      ('tf_example', 'retinanet_spinenet_coco', [640, 384]),
+      ('tflite', 'retinanet_spinenet_coco', [640, 640]),
+  )
+  def test_export_normalized_coordinates_no_nms(
+      self,
+      input_type,
+      experiment_name,
+      image_size,
+  ):
+    tmp_dir = self.get_temp_dir()
+    module = self._get_detection_module(
+        experiment_name,
+        input_type,
+        apply_nms=False,
+        normalized_coordinates=True,
+    )
+
+    self._export_from_module(module, input_type, tmp_dir)
+
+    imported = tf.saved_model.load(tmp_dir)
+    detection_fn = imported.signatures['serving_default']
+
+    images = self._get_dummy_input(
+        input_type, batch_size=1, image_size=image_size
+    )
+    outputs = detection_fn(tf.constant(images))
+
+    min_values = tf.math.reduce_min(outputs['decoded_boxes'])
+    max_values = tf.math.reduce_max(outputs['decoded_boxes'])
+    self.assertAllGreaterEqual(
+        min_values.numpy(), tf.zeros_like(min_values).numpy()
+    )
+    self.assertAllLessEqual(
+        max_values.numpy(), tf.ones_like(max_values).numpy()
+    )
+
 
 if __name__ == '__main__':
   tf.test.main()