From 4992ebfb95dd0f8c952e3014aeb7fc5ed1de8f93 Mon Sep 17 00:00:00 2001
From: Fan Yang
Date: Mon, 27 Feb 2023 17:01:25 -0800
Subject: [PATCH] Internal change

PiperOrigin-RevId: 512771693
---
 official/core/config_definitions.py          |   3 +
 official/core/train_lib.py                   |  46 +++--
 official/vision/tasks/retinanet.py           |  35 +++-
 official/vision/train.py                     |  13 +-
 .../object_detection/visualization_utils.py  | 163 ++++++++++++++++--
 official/vision/utils/summary_manager.py     |  84 +++++++++
 6 files changed, 312 insertions(+), 32 deletions(-)
 create mode 100644 official/vision/utils/summary_manager.py

diff --git a/official/core/config_definitions.py b/official/core/config_definitions.py
index 50ab8058a..7c2772657 100644
--- a/official/core/config_definitions.py
+++ b/official/core/config_definitions.py
@@ -289,6 +289,9 @@ class TaskConfig(base_config.Config):
   # DEPRECATED b/264611883
   differential_privacy_config: Optional[
       dp_configs.DifferentialPrivacyConfig] = None
+  # Whether to write image summaries. Useful to visualize model predictions.
+  # Only works for vision tasks.
+  allow_image_summary: bool = False
 
 
 @dataclasses.dataclass
diff --git a/official/core/train_lib.py b/official/core/train_lib.py
index 378efa787..d68e40b81 100644
--- a/official/core/train_lib.py
+++ b/official/core/train_lib.py
@@ -68,7 +68,9 @@ class OrbitExperimentRunner:
       train_actions: Optional[List[orbit.Action]] = None,
       eval_actions: Optional[List[orbit.Action]] = None,
       trainer: Optional[base_trainer.Trainer] = None,
-      controller_cls=orbit.Controller
+      controller_cls=orbit.Controller,
+      summary_manager: Optional[orbit.utils.SummaryManager] = None,
+      eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
   ):
     """Constructor.
 
@@ -88,6 +90,10 @@ class OrbitExperimentRunner:
         the strategy.scope().
       controller_cls: The controller class to manage the train and eval
         process. Must be a orbit.Controller subclass.
+      summary_manager: Instance of the summary manager to override the
+        default summary manager.
+      eval_summary_manager: Instance of the eval summary manager to override
+        the default eval summary manager.
""" self.strategy = distribution_strategy or tf.distribute.get_strategy() self._params = params @@ -101,6 +107,8 @@ class OrbitExperimentRunner: evaluate=('eval' in mode) or run_post_eval) assert self.trainer is not None self._checkpoint_manager = self._maybe_build_checkpoint_manager() + self._summary_manager = summary_manager + self._eval_summary_manager = eval_summary_manager self._controller = self._build_controller( trainer=self.trainer if 'train' in mode else None, evaluator=self.trainer, @@ -201,6 +209,13 @@ class OrbitExperimentRunner: eval_actions += actions.get_eval_actions(self.params, evaluator, self.model_dir) + if save_summary: + eval_summary_dir = os.path.join( + self.model_dir, self.params.trainer.validation_summary_subdir + ) + else: + eval_summary_dir = None + controller = controller_cls( strategy=self.strategy, trainer=trainer, @@ -208,15 +223,18 @@ class OrbitExperimentRunner: global_step=self.trainer.global_step, steps_per_loop=self.params.trainer.steps_per_loop, checkpoint_manager=self.checkpoint_manager, - summary_dir=os.path.join(self.model_dir, 'train') if - (save_summary) else None, - eval_summary_dir=os.path.join( - self.model_dir, self.params.trainer.validation_summary_subdir) if - (save_summary) else None, - summary_interval=self.params.trainer.summary_interval if - (save_summary) else None, + summary_dir=os.path.join(self.model_dir, 'train') + if (save_summary) + else None, + eval_summary_dir=eval_summary_dir, + summary_interval=self.params.trainer.summary_interval + if (save_summary) + else None, train_actions=train_actions, - eval_actions=eval_actions) + eval_actions=eval_actions, + summary_manager=self._summary_manager, + eval_summary_manager=self._eval_summary_manager, + ) return controller def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]: @@ -284,7 +302,9 @@ def run_experiment( train_actions: Optional[List[orbit.Action]] = None, eval_actions: Optional[List[orbit.Action]] = None, trainer: Optional[base_trainer.Trainer] = None, - controller_cls=orbit.Controller + controller_cls=orbit.Controller, + summary_manager: Optional[orbit.utils.SummaryManager] = None, + eval_summary_manager: Optional[orbit.utils.SummaryManager] = None, ) -> Tuple[tf.keras.Model, Mapping[str, Any]]: """Runs train/eval configured by the experiment params. @@ -304,6 +324,10 @@ def run_experiment( strategy.scope(). controller_cls: The controller class to manage the train and eval process. Must be a orbit.Controller subclass. + summary_manager: Instance of the summary manager to override default summary + manager. + eval_summary_manager: Instance of the eval summary manager to override + default eval summary manager. Returns: A 2-tuple of (model, eval_logs). 
@@ -323,5 +347,7 @@ def run_experiment(
       eval_actions=eval_actions,
       trainer=trainer,
       controller_cls=controller_cls,
+      summary_manager=summary_manager,
+      eval_summary_manager=eval_summary_manager,
   )
   return runner.run()
diff --git a/official/vision/tasks/retinanet.py b/official/vision/tasks/retinanet.py
index 372d5fd2a..2012b74f9 100644
--- a/official/vision/tasks/retinanet.py
+++ b/official/vision/tasks/retinanet.py
@@ -32,6 +32,7 @@ from official.vision.evaluation import coco_evaluator
 from official.vision.losses import focal_loss
 from official.vision.losses import loss_utils
 from official.vision.modeling import factory
+from official.vision.utils.object_detection import visualization_utils
 
 
 @task_factory.register_task_cls(exp_cfg.RetinaNetTask)
@@ -264,15 +265,20 @@ class RetinaNetTask(base_task.Task):
       metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
 
     if not training:
-      if self.task_config.validation_data.tfds_name and self.task_config.annotation_file:
+      if (
+          self.task_config.validation_data.tfds_name
+          and self.task_config.annotation_file
+      ):
         raise ValueError(
-            "Can't evaluate using annotation file when TFDS is used.")
+            "Can't evaluate using annotation file when TFDS is used."
+        )
       if self._task_config.use_coco_metrics:
         self.coco_metric = coco_evaluator.COCOEvaluator(
             annotation_file=self.task_config.annotation_file,
             include_mask=False,
             per_category_metrics=self.task_config.per_category_metrics,
-            max_num_eval_detections=self.task_config.max_num_eval_detections)
+            max_num_eval_detections=self.task_config.max_num_eval_detections,
+        )
       if self._task_config.use_wod_metrics:
         # To use Waymo open dataset metrics, please install one of the pip
         # package `waymo-open-dataset-tf-*` from
@@ -402,6 +408,14 @@ class RetinaNetTask(base_task.Task):
       for m in metrics:
         m.update_state(all_losses[m.name])
         logs.update({m.name: m.result()})
+
+      if (
+          hasattr(self.task_config, 'allow_image_summary')
+          and self.task_config.allow_image_summary
+      ):
+        logs.update(
+            {'visualization': (tf.cast(features, dtype=tf.float32), outputs)}
+        )
     return logs
 
   def aggregate_logs(self, state=None, step_outputs=None):
@@ -418,7 +432,12 @@ class RetinaNetTask(base_task.Task):
     if state is None:
       # Create an arbitrary state to indicate it's not the first step in the
       # following calls to this function.
-      state = True
+      state = {}
+
+    # Update detection state for writing summary if there are artifacts for
+    # visualization.
+    if 'visualization' in step_outputs:
+      state.update(visualization_utils.update_detection_state(step_outputs))
     return state
 
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
@@ -427,4 +446,12 @@ class RetinaNetTask(base_task.Task):
       logs.update(self.coco_metric.result())
     if self._task_config.use_wod_metrics:
       logs.update(self.wod_metric.result())
+
+    # Add visualization for summary.
+    if 'image' in aggregated_logs:
+      validation_outputs = visualization_utils.visualize_outputs(
+          logs=aggregated_logs, task_config=self.task_config
+      )
+      logs.update({'image/validation_outputs': validation_outputs})
+
     return logs
diff --git a/official/vision/train.py b/official/vision/train.py
index f6b6f3290..5c9997878 100644
--- a/official/vision/train.py
+++ b/official/vision/train.py
@@ -26,9 +26,9 @@ from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
-# pylint: disable=unused-import
-from official.vision import registry_imports
-# pylint: enable=unused-import
+from official.vision import registry_imports  # pylint: disable=unused-import
+from official.vision.utils import summary_manager
+
 
 FLAGS = flags.FLAGS
 
@@ -53,7 +53,12 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
           task=task,
           mode=FLAGS.mode,
          params=params,
-          model_dir=model_dir)
+          model_dir=model_dir,
+          summary_manager=None,
+          eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
+              params=params, model_dir=model_dir
+          ),
+      )
 
       keep_training = False
     except tf.errors.OpError as e:
diff --git a/official/vision/utils/object_detection/visualization_utils.py b/official/vision/utils/object_detection/visualization_utils.py
index 48159e6ba..3d577b04f 100644
--- a/official/vision/utils/object_detection/visualization_utils.py
+++ b/official/vision/utils/object_detection/visualization_utils.py
@@ -16,10 +16,10 @@
 These functions often receive an image, perform some visualization on the
 image. The functions do not return a value, instead they modify the image
 itself.
-
 """
 import collections
 import functools
+from typing import Any, Dict
 
 from absl import logging
 # Set headless-friendly backend.
@@ -27,14 +27,15 @@ import matplotlib
 matplotlib.use('Agg')  # pylint: disable=multiple-statements
 import matplotlib.pyplot as plt  # pylint: disable=g-import-not-at-top
 import numpy as np
-import PIL.Image as Image
-import PIL.ImageColor as ImageColor
-import PIL.ImageDraw as ImageDraw
-import PIL.ImageFont as ImageFont
+from PIL import Image
+from PIL import ImageColor
+from PIL import ImageDraw
+from PIL import ImageFont
 import six
 import tensorflow as tf
 
 from official.vision.ops import box_ops
+from official.vision.ops import preprocess_ops
 from official.vision.utils.object_detection import shape_utils
 
 _TITLE_LEFT_MARGIN = 10
@@ -215,15 +216,24 @@ def draw_bounding_box_on_image(image,
     text_bottom = bottom + total_display_str_height
   # Reverse list and print from bottom to top.
   for display_str in display_str_list[::-1]:
-    text_width, text_height = font.getsize(display_str)
-    margin = np.ceil(0.05 * text_height)
-    draw.rectangle([(left, text_bottom - text_height - 2 * margin),
-                    (left + text_width, text_bottom)],
-                   fill=color)
-    draw.text((left + margin, text_bottom - text_height - margin),
-              display_str,
-              fill='black',
-              font=font)
+    try:
+      text_width, text_height = font.getsize(display_str)
+      margin = np.ceil(0.05 * text_height)
+      draw.rectangle(
+          [
+              (left, text_bottom - text_height - 2 * margin),
+              (left + text_width, text_bottom),
+          ],
+          fill=color,
+      )
+      draw.text(
+          (left + margin, text_bottom - text_height - margin),
+          display_str,
+          fill='black',
+          font=font,
+      )
+    except ValueError:
+      pass
     text_bottom -= text_height - 2 * margin
 
 
@@ -336,6 +346,95 @@ def _resize_original_image(image, image_shape):
   return tf.cast(tf.squeeze(image, 0), tf.uint8)
 
 
+def visualize_outputs(
+    logs,
+    task_config,
+    original_image_spatial_shape=None,
+    true_image_shape=None,
+    max_boxes_to_draw=20,
+    min_score_thresh=0.2,
+    use_normalized_coordinates=False,
+):
+  """Visualizes the detection outputs.
+
+  It extracts images and predictions from logs and draws visualizations on the
+  input images. By default, it requires `detection_boxes`, `detection_classes`
+  and `detection_scores` in the prediction, and optionally accepts
+  `detection_keypoints` and `detection_masks`.
+
+  Args:
+    logs: A dictionary of logs that contains images and predictions.
+    task_config: A task config.
+    original_image_spatial_shape: A [N, 2] tensor containing the spatial size
+      of the original image.
+    true_image_shape: A [N, 3] tensor containing the spatial size of unpadded
+      original_image.
+    max_boxes_to_draw: The maximum number of boxes to draw on an image. Default
+      20.
+    min_score_thresh: The minimum score threshold for visualization. Default
+      0.2.
+    use_normalized_coordinates: Whether to assume boxes and keypoints are in
+      normalized coordinates (as opposed to absolute coordinates). Default is
+      False.
+
+  Returns:
+    A 4D tensor with predictions (boxes, segments and/or keypoints) drawn on
+    each image.
+ """ + images = logs['image'] + boxes = logs['detection_boxes'] + classes = tf.cast(logs['detection_classes'], dtype=tf.int32) + scores = logs['detection_scores'] + num_classes = task_config.model.num_classes + + keypoints = ( + logs['detection_keypoints'] if 'detection_keypoints' in logs else None + ) + instance_masks = ( + logs['detection_masks'] if 'detection_masks' in logs else None + ) + + category_index = {} + for i in range(1, num_classes + 1): + category_index[i] = {'id': i, 'name': str(i)} + + def _denormalize_images(images: tf.Tensor) -> tf.Tensor: + images *= tf.constant( + preprocess_ops.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype + ) + images += tf.constant( + preprocess_ops.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype + ) + return tf.cast(images, dtype=tf.uint8) + + images = tf.nest.map_structure( + tf.identity, + tf.map_fn( + _denormalize_images, + elems=images, + fn_output_signature=tf.TensorSpec( + shape=images.shape.as_list()[1:], dtype=tf.uint8 + ), + parallel_iterations=32, + ), + ) + + return draw_bounding_boxes_on_image_tensors( + images, + boxes, + classes, + scores, + category_index, + original_image_spatial_shape, + true_image_shape, + instance_masks, + keypoints, + max_boxes_to_draw, + min_score_thresh, + use_normalized_coordinates, + ) + + def draw_bounding_boxes_on_image_tensors(images, boxes, classes, @@ -722,3 +821,39 @@ def add_hist_image_summary(values, bins, name): hist_plot = tf.compat.v1.py_func(hist_plot, [values, bins], tf.uint8) tf.compat.v1.summary.image(name, hist_plot) + + +def update_detection_state(step_outputs=None) -> Dict[str, Any]: + """Updates detection state to optionally add input image and predictions.""" + state = {} + if step_outputs: + state['image'] = tf.concat(step_outputs['visualization'][0], axis=0) + state['detection_boxes'] = tf.concat( + step_outputs['visualization'][1]['detection_boxes'], axis=0 + ) + state['detection_classes'] = tf.concat( + step_outputs['visualization'][1]['detection_classes'], axis=0 + ) + state['detection_scores'] = tf.concat( + step_outputs['visualization'][1]['detection_scores'], axis=0 + ) + + if 'detection_kpts' in step_outputs['visualization'][1]: + detection_keypoints = step_outputs['visualization'][1]['detection_kpts'] + elif 'detection_keypoints' in step_outputs['visualization'][1]: + detection_keypoints = step_outputs['visualization'][1][ + 'detection_keypoints' + ] + else: + detection_keypoints = None + + if detection_keypoints: + state['detection_keypoints'] = tf.concat(detection_keypoints, axis=0) + + detection_masks = step_outputs['visualization'][1].get( + 'detection_masks', None + ) + if detection_masks: + state['detection_masks'] = tf.concat(detection_masks, axis=0) + + return state diff --git a/official/vision/utils/summary_manager.py b/official/vision/utils/summary_manager.py new file mode 100644 index 000000000..911f6c9af --- /dev/null +++ b/official/vision/utils/summary_manager.py @@ -0,0 +1,84 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Custom summary manager utilities."""
+import os
+from typing import Any, Callable, Dict, Optional
+
+import orbit
+import tensorflow as tf
+from official.core import config_definitions
+
+
+class ImageScalarSummaryManager(orbit.utils.SummaryManager):
+  """Custom summary manager that writes scalar and image summaries."""
+
+  def __init__(
+      self,
+      summary_dir: str,
+      scalar_summary_fn: Callable[..., Any],
+      image_summary_fn: Optional[Callable[..., Any]],
+      max_outputs: int = 20,
+      global_step=None,
+  ):
+    """Initializes the `ImageScalarSummaryManager` instance."""
+    self._enabled = summary_dir is not None
+    self._summary_dir = summary_dir
+    self._scalar_summary_fn = scalar_summary_fn
+    self._image_summary_fn = image_summary_fn
+    self._summary_writers = {}
+    self._max_outputs = max_outputs
+
+    if global_step is None:
+      self._global_step = tf.summary.experimental.get_step()
+    else:
+      self._global_step = global_step
+
+  def _write_summaries(
+      self, summary_dict: Dict[str, Any], relative_path: str = ''
+  ):
+    for name, value in summary_dict.items():
+      if isinstance(value, dict):
+        self._write_summaries(
+            value, relative_path=os.path.join(relative_path, name)
+        )
+      else:
+        with self.summary_writer(relative_path).as_default():
+          if name.startswith('image/'):
+            self._image_summary_fn(
+                name, value, self._global_step, max_outputs=self._max_outputs
+            )
+          else:
+            self._scalar_summary_fn(name, value, self._global_step)
+
+
+def maybe_build_eval_summary_manager(
+    params: config_definitions.ExperimentConfig, model_dir: str
+) -> Optional[orbit.utils.SummaryManager]:
+  """Maybe creates a SummaryManager."""
+
+  if (
+      hasattr(params.task, 'allow_image_summary')
+      and params.task.allow_image_summary
+  ):
+    eval_summary_dir = os.path.join(
+        model_dir, params.trainer.validation_summary_subdir
+    )
+
+    return ImageScalarSummaryManager(
+        eval_summary_dir,
+        scalar_summary_fn=tf.summary.scalar,
+        image_summary_fn=tf.summary.image,
+    )
+  return None
-- 
GitLab
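Usage note: below is a minimal sketch of how a trainer binary can wire up the new summary-manager hooks, mirroring the `official/vision/train.py` change in this patch. Only `run_experiment`'s new `summary_manager`/`eval_summary_manager` keyword arguments and `summary_manager.maybe_build_eval_summary_manager` come from this change; the experiment name `retinanet_resnetfpn_coco`, the `MirroredStrategy` choice, and the model directory are illustrative assumptions.

import tensorflow as tf

from official.core import exp_factory
from official.core import task_factory
from official.core import train_lib
from official.vision import registry_imports  # pylint: disable=unused-import
from official.vision.utils import summary_manager

# Assumed: a registered detection experiment; any task that emits
# 'visualization' artifacts (e.g. RetinaNet above) works the same way.
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
# New TaskConfig field added by this patch; it defaults to False.
params.override({'task': {'allow_image_summary': True}})

model_dir = '/tmp/retinanet_vis'  # placeholder output directory
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  task = task_factory.get_task(params.task, logging_dir=model_dir)

model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=task,
    mode='train_and_eval',
    params=params,
    model_dir=model_dir,
    summary_manager=None,  # keep Orbit's default train summary manager
    # Returns an ImageScalarSummaryManager that routes 'image/...' keys to
    # tf.summary.image when allow_image_summary is set, otherwise None so
    # the controller falls back to its default eval summary manager.
    eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
        params=params, model_dir=model_dir
    ),
)

With this wiring, the eval pass writes the usual scalar metrics plus an `image/validation_outputs` summary (predictions drawn on denormalized input images) under the `validation_summary_subdir` of the model directory.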