提交 4992ebfb 编写于 作者: F Fan Yang 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 512771693
上级 fec44b03
......@@ -289,6 +289,9 @@ class TaskConfig(base_config.Config):
# DEPRECATED b/264611883
differential_privacy_config: Optional[
dp_configs.DifferentialPrivacyConfig] = None
# Whether to show image summary. Useful to visualize model predictions. Only
# work for vision tasks.
allow_image_summary: bool = False
@dataclasses.dataclass
......
......@@ -68,7 +68,9 @@ class OrbitExperimentRunner:
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
):
"""Constructor.
......@@ -88,6 +90,10 @@ class OrbitExperimentRunner:
the strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
summary_manager: Instance of the summary manager to override default
summary manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
"""
self.strategy = distribution_strategy or tf.distribute.get_strategy()
self._params = params
......@@ -101,6 +107,8 @@ class OrbitExperimentRunner:
evaluate=('eval' in mode) or run_post_eval)
assert self.trainer is not None
self._checkpoint_manager = self._maybe_build_checkpoint_manager()
self._summary_manager = summary_manager
self._eval_summary_manager = eval_summary_manager
self._controller = self._build_controller(
trainer=self.trainer if 'train' in mode else None,
evaluator=self.trainer,
......@@ -201,6 +209,13 @@ class OrbitExperimentRunner:
eval_actions += actions.get_eval_actions(self.params, evaluator,
self.model_dir)
if save_summary:
eval_summary_dir = os.path.join(
self.model_dir, self.params.trainer.validation_summary_subdir
)
else:
eval_summary_dir = None
controller = controller_cls(
strategy=self.strategy,
trainer=trainer,
......@@ -208,15 +223,18 @@ class OrbitExperimentRunner:
global_step=self.trainer.global_step,
steps_per_loop=self.params.trainer.steps_per_loop,
checkpoint_manager=self.checkpoint_manager,
summary_dir=os.path.join(self.model_dir, 'train') if
(save_summary) else None,
eval_summary_dir=os.path.join(
self.model_dir, self.params.trainer.validation_summary_subdir) if
(save_summary) else None,
summary_interval=self.params.trainer.summary_interval if
(save_summary) else None,
summary_dir=os.path.join(self.model_dir, 'train')
if (save_summary)
else None,
eval_summary_dir=eval_summary_dir,
summary_interval=self.params.trainer.summary_interval
if (save_summary)
else None,
train_actions=train_actions,
eval_actions=eval_actions)
eval_actions=eval_actions,
summary_manager=self._summary_manager,
eval_summary_manager=self._eval_summary_manager,
)
return controller
def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
......@@ -284,7 +302,9 @@ def run_experiment(
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
......@@ -304,6 +324,10 @@ def run_experiment(
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
summary_manager: Instance of the summary manager to override default summary
manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
Returns:
A 2-tuple of (model, eval_logs).
......@@ -323,5 +347,7 @@ def run_experiment(
eval_actions=eval_actions,
trainer=trainer,
controller_cls=controller_cls,
summary_manager=summary_manager,
eval_summary_manager=eval_summary_manager,
)
return runner.run()
......@@ -32,6 +32,7 @@ from official.vision.evaluation import coco_evaluator
from official.vision.losses import focal_loss
from official.vision.losses import loss_utils
from official.vision.modeling import factory
from official.vision.utils.object_detection import visualization_utils
@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
......@@ -264,15 +265,20 @@ class RetinaNetTask(base_task.Task):
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
if not training:
if self.task_config.validation_data.tfds_name and self.task_config.annotation_file:
if (
self.task_config.validation_data.tfds_name
and self.task_config.annotation_file
):
raise ValueError(
"Can't evaluate using annotation file when TFDS is used.")
"Can't evaluate using annotation file when TFDS is used."
)
if self._task_config.use_coco_metrics:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self.task_config.annotation_file,
include_mask=False,
per_category_metrics=self.task_config.per_category_metrics,
max_num_eval_detections=self.task_config.max_num_eval_detections)
max_num_eval_detections=self.task_config.max_num_eval_detections,
)
if self._task_config.use_wod_metrics:
# To use Waymo open dataset metrics, please install one of the pip
# package `waymo-open-dataset-tf-*` from
......@@ -402,6 +408,14 @@ class RetinaNetTask(base_task.Task):
for m in metrics:
m.update_state(all_losses[m.name])
logs.update({m.name: m.result()})
if (
hasattr(self.task_config, 'allow_image_summary')
and self.task_config.allow_image_summary
):
logs.update(
{'visualization': (tf.cast(features, dtype=tf.float32), outputs)}
)
return logs
def aggregate_logs(self, state=None, step_outputs=None):
......@@ -418,7 +432,12 @@ class RetinaNetTask(base_task.Task):
if state is None:
# Create an arbitrary state to indicate it's not the first step in the
# following calls to this function.
state = True
state = {}
# Update detection state for writing summary if there are artifacts for
# visualization.
if 'visualization' in step_outputs:
state.update(visualization_utils.update_detection_state(step_outputs))
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
......@@ -427,4 +446,12 @@ class RetinaNetTask(base_task.Task):
logs.update(self.coco_metric.result())
if self._task_config.use_wod_metrics:
logs.update(self.wod_metric.result())
# Add visualization for summary.
if 'image' in aggregated_logs:
validation_outputs = visualization_utils.visualize_outputs(
logs=aggregated_logs, task_config=self.task_config
)
logs.update({'image/validation_outputs': validation_outputs})
return logs
......@@ -26,9 +26,9 @@ from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.vision import registry_imports
# pylint: enable=unused-import
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.utils import summary_manager
FLAGS = flags.FLAGS
......@@ -53,7 +53,12 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
model_dir=model_dir,
summary_manager=None,
eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
params=params, model_dir=model_dir
),
)
keep_training = False
except tf.errors.OpError as e:
......
......@@ -16,10 +16,10 @@
These functions often receive an image, perform some visualization on the image.
The functions do not return a value, instead they modify the image itself.
"""
import collections
import functools
from typing import Any, Dict
from absl import logging
# Set headless-friendly backend.
......@@ -27,14 +27,15 @@ import matplotlib
matplotlib.use('Agg') # pylint: disable=multiple-statements
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
import numpy as np
import PIL.Image as Image
import PIL.ImageColor as ImageColor
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
from PIL import Image
from PIL import ImageColor
from PIL import ImageDraw
from PIL import ImageFont
import six
import tensorflow as tf
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
from official.vision.utils.object_detection import shape_utils
_TITLE_LEFT_MARGIN = 10
......@@ -215,15 +216,24 @@ def draw_bounding_box_on_image(image,
text_bottom = bottom + total_display_str_height
# Reverse list and print from bottom to top.
for display_str in display_str_list[::-1]:
text_width, text_height = font.getsize(display_str)
margin = np.ceil(0.05 * text_height)
draw.rectangle([(left, text_bottom - text_height - 2 * margin),
(left + text_width, text_bottom)],
fill=color)
draw.text((left + margin, text_bottom - text_height - margin),
display_str,
fill='black',
font=font)
try:
text_width, text_height = font.getsize(display_str)
margin = np.ceil(0.05 * text_height)
draw.rectangle(
[
(left, text_bottom - text_height - 2 * margin),
(left + text_width, text_bottom),
],
fill=color,
)
draw.text(
(left + margin, text_bottom - text_height - margin),
display_str,
fill='black',
font=font,
)
except ValueError:
pass
text_bottom -= text_height - 2 * margin
......@@ -336,6 +346,95 @@ def _resize_original_image(image, image_shape):
return tf.cast(tf.squeeze(image, 0), tf.uint8)
def visualize_outputs(
logs,
task_config,
original_image_spatial_shape=None,
true_image_shape=None,
max_boxes_to_draw=20,
min_score_thresh=0.2,
use_normalized_coordinates=False,
):
"""Visualizes the detection outputs.
It extracts images and predictions from logs and draws visualization on input
images. By default, it requires `detection_boxes`, `detection_classes` and
`detection_scores` in the prediction, and optionally accepts
`detection_keypoints` and `detection_masks`.
Args:
logs: A dictionaty of log that contains images and predictions.
task_config: A task config.
original_image_spatial_shape: A [N, 2] tensor containing the spatial size of
the original image.
true_image_shape: A [N, 3] tensor containing the spatial size of unpadded
original_image.
max_boxes_to_draw: The maximum number of boxes to draw on an image. Default
20.
min_score_thresh: The minimum score threshold for visualization. Default
0.2.
use_normalized_coordinates: Whether to assume boxes and kepoints are in
normalized coordinates (as opposed to absolute coordiantes). Default is
True.
Returns:
A 4D tensor with predictions (boxes, segments and/or keypoints) drawn on
each image.
"""
images = logs['image']
boxes = logs['detection_boxes']
classes = tf.cast(logs['detection_classes'], dtype=tf.int32)
scores = logs['detection_scores']
num_classes = task_config.model.num_classes
keypoints = (
logs['detection_keypoints'] if 'detection_keypoints' in logs else None
)
instance_masks = (
logs['detection_masks'] if 'detection_masks' in logs else None
)
category_index = {}
for i in range(1, num_classes + 1):
category_index[i] = {'id': i, 'name': str(i)}
def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
images *= tf.constant(
preprocess_ops.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype
)
images += tf.constant(
preprocess_ops.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype
)
return tf.cast(images, dtype=tf.uint8)
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
_denormalize_images,
elems=images,
fn_output_signature=tf.TensorSpec(
shape=images.shape.as_list()[1:], dtype=tf.uint8
),
parallel_iterations=32,
),
)
return draw_bounding_boxes_on_image_tensors(
images,
boxes,
classes,
scores,
category_index,
original_image_spatial_shape,
true_image_shape,
instance_masks,
keypoints,
max_boxes_to_draw,
min_score_thresh,
use_normalized_coordinates,
)
def draw_bounding_boxes_on_image_tensors(images,
boxes,
classes,
......@@ -722,3 +821,39 @@ def add_hist_image_summary(values, bins, name):
hist_plot = tf.compat.v1.py_func(hist_plot, [values, bins], tf.uint8)
tf.compat.v1.summary.image(name, hist_plot)
def update_detection_state(step_outputs=None) -> Dict[str, Any]:
"""Updates detection state to optionally add input image and predictions."""
state = {}
if step_outputs:
state['image'] = tf.concat(step_outputs['visualization'][0], axis=0)
state['detection_boxes'] = tf.concat(
step_outputs['visualization'][1]['detection_boxes'], axis=0
)
state['detection_classes'] = tf.concat(
step_outputs['visualization'][1]['detection_classes'], axis=0
)
state['detection_scores'] = tf.concat(
step_outputs['visualization'][1]['detection_scores'], axis=0
)
if 'detection_kpts' in step_outputs['visualization'][1]:
detection_keypoints = step_outputs['visualization'][1]['detection_kpts']
elif 'detection_keypoints' in step_outputs['visualization'][1]:
detection_keypoints = step_outputs['visualization'][1][
'detection_keypoints'
]
else:
detection_keypoints = None
if detection_keypoints:
state['detection_keypoints'] = tf.concat(detection_keypoints, axis=0)
detection_masks = step_outputs['visualization'][1].get(
'detection_masks', None
)
if detection_masks:
state['detection_masks'] = tf.concat(detection_masks, axis=0)
return state
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom summary manager utilities."""
import os
from typing import Any, Callable, Dict, Optional
import orbit
import tensorflow as tf
from official.core import config_definitions
class ImageScalarSummaryManager(orbit.utils.SummaryManager):
"""Class of custom summary manager that creates scalar and image summary."""
def __init__(
self,
summary_dir: str,
scalar_summary_fn: Callable[..., Any],
image_summary_fn: Optional[Callable[..., Any]],
max_outputs: int = 20,
global_step=None,
):
"""Initializes the `ImageScalarSummaryManager` instance."""
self._enabled = summary_dir is not None
self._summary_dir = summary_dir
self._scalar_summary_fn = scalar_summary_fn
self._image_summary_fn = image_summary_fn
self._summary_writers = {}
self._max_outputs = max_outputs
if global_step is None:
self._global_step = tf.summary.experimental.get_step()
else:
self._global_step = global_step
def _write_summaries(
self, summary_dict: Dict[str, Any], relative_path: str = ''
):
for name, value in summary_dict.items():
if isinstance(value, dict):
self._write_summaries(
value, relative_path=os.path.join(relative_path, name)
)
else:
with self.summary_writer(relative_path).as_default():
if name.startswith('image/'):
self._image_summary_fn(
name, value, self._global_step, max_outputs=self._max_outputs
)
else:
self._scalar_summary_fn(name, value, self._global_step)
def maybe_build_eval_summary_manager(
params: config_definitions.ExperimentConfig, model_dir: str
) -> Optional[orbit.utils.SummaryManager]:
"""Maybe creates a SummaryManager."""
if (
hasattr(params.task, 'allow_image_summary')
and params.task.allow_image_summary
):
eval_summary_dir = os.path.join(
model_dir, params.trainer.validation_summary_subdir
)
return ImageScalarSummaryManager(
eval_summary_dir,
scalar_summary_fn=tf.summary.scalar,
image_summary_fn=tf.summary.image,
)
return None
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册