# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom summary manager utilities."""

import os
from typing import Any, Callable, Dict, Optional

import orbit
import tensorflow as tf

from official.core import config_definitions


class ImageScalarSummaryManager(orbit.utils.SummaryManager):
  """Summary manager that writes both scalar and image summaries.

  Entries of a summary dict whose name starts with `'image/'` are handed to
  `image_summary_fn`; every other non-dict entry is handed to
  `scalar_summary_fn`. Nested dicts are written under a sub-path named after
  their key.
  """

  def __init__(
      self,
      summary_dir: str,
      scalar_summary_fn: Callable[..., Any],
      image_summary_fn: Optional[Callable[..., Any]],
      max_outputs: int = 20,
      global_step=None,
  ) -> None:
    """Initializes the `ImageScalarSummaryManager` instance.

    Args:
      summary_dir: Directory associated with this manager. The manager is
        flagged as disabled (`self._enabled` is False) when this is None.
      scalar_summary_fn: Callable invoked as `fn(name, value, step)` for
        every entry whose name does not start with `'image/'`.
      image_summary_fn: Callable invoked as
        `fn(name, value, step, max_outputs=max_outputs)` for every entry whose
        name starts with `'image/'`. If None, such entries are skipped.
      max_outputs: Value forwarded to `image_summary_fn` as `max_outputs`.
      global_step: Step value passed to the summary functions. When None,
        `tf.summary.experimental.get_step()` is used.
    """
    # NOTE(review): the base-class initializer is not invoked; the attributes
    # are assigned here directly. Confirm they stay in sync with what
    # `orbit.utils.SummaryManager` expects.
    self._enabled = summary_dir is not None
    self._summary_dir = summary_dir
    self._scalar_summary_fn = scalar_summary_fn
    self._image_summary_fn = image_summary_fn
    self._summary_writers = {}
    self._max_outputs = max_outputs

    if global_step is None:
      self._global_step = tf.summary.experimental.get_step()
    else:
      self._global_step = global_step

  def _write_summaries(
      self, summary_dict: Dict[str, Any], relative_path: str = ''
  ) -> None:
    """Recursively writes the entries of `summary_dict`.

    Args:
      summary_dict: Mapping from summary name to either a value to write or a
        nested dict of further summaries.
      relative_path: Path passed to `self.summary_writer` (presumably
        inherited from `orbit.utils.SummaryManager`) to select the writer;
        nested dict keys are appended to it with `os.path.join`.
    """
    for name, value in summary_dict.items():
      if isinstance(value, dict):
        # Nested dicts go to a writer located one level deeper.
        self._write_summaries(
            value, relative_path=os.path.join(relative_path, name)
        )
        continue

      is_image = name.startswith('image/')
      if is_image and self._image_summary_fn is None:
        # `image_summary_fn` is optional: without one there is nothing that
        # can emit this entry, so skip it instead of calling None.
        continue

      with self.summary_writer(relative_path).as_default():
        if is_image:
          self._image_summary_fn(
              name, value, self._global_step, max_outputs=self._max_outputs
          )
        else:
          self._scalar_summary_fn(name, value, self._global_step)


def maybe_build_eval_summary_manager(
    params: config_definitions.ExperimentConfig, model_dir: str
) -> Optional[orbit.utils.SummaryManager]:
  """Maybe creates a SummaryManager.

  Args:
    params: Experiment config. `params.task.allow_image_summary` (if present)
      decides whether a manager is built, and
      `params.trainer.validation_summary_subdir` names the sub-directory.
    model_dir: Directory that the validation summary sub-directory is joined
      onto.

  Returns:
    An `ImageScalarSummaryManager` using `tf.summary.scalar` and
    `tf.summary.image` when `params.task` has a truthy `allow_image_summary`
    attribute; otherwise None.
  """
  # A missing attribute is treated the same as a falsy one.
  if not getattr(params.task, 'allow_image_summary', False):
    return None

  eval_summary_dir = os.path.join(
      model_dir, params.trainer.validation_summary_subdir
  )
  return ImageScalarSummaryManager(
      eval_summary_dir,
      scalar_summary_fn=tf.summary.scalar,
      image_summary_fn=tf.summary.image,
  )