# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
import tempfile
from typing import Any, List, Mapping, Optional, Tuple

# Import libraries

from absl import logging
import orbit
import tensorflow as tf

from official.core import actions
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import train_utils

maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter


class OrbitExperimentRunner:
  """Runs experiment with Orbit training loop.

  The default experiment runner for model garden experiments. User can
  customize the experiment pipeline by subclassing this class and replacing
  components or functions.

  For example, an experiment runner with a customized checkpoint manager:

  ```python
  class MyExpRunnerWithExporter(OrbitExperimentRunner):
    def _maybe_build_checkpoint_manager(self):
      # Replaces the default CheckpointManager with a customized one.
      return MyCheckpointManager(*args)

  # In user code, instead of the original
  # `OrbitExperimentRunner(...).run()`, the user can now do:
  MyExpRunnerWithExporter(**needed_kwargs).run()
  ```

  Similar overrides can be applied to other components.
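
  For instance, a minimal illustrative sketch that disables the best
  checkpoint exporter by overriding `_build_best_checkpoint_exporter`
  (`MyExpRunnerNoExporter` is a hypothetical name):

  ```python
  class MyExpRunnerNoExporter(OrbitExperimentRunner):

    def _build_best_checkpoint_exporter(self):
      # Returning None builds the trainer without a best-checkpoint exporter.
      return None
  ```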
  """

  def __init__(
      self,
      distribution_strategy: tf.distribute.Strategy,
      task: base_task.Task,
      mode: str,
      params: config_definitions.ExperimentConfig,
      model_dir: str,
      run_post_eval: bool = False,
      save_summary: bool = True,
      train_actions: Optional[List[orbit.Action]] = None,
      eval_actions: Optional[List[orbit.Action]] = None,
      trainer: Optional[base_trainer.Trainer] = None,
      controller_cls=orbit.Controller,
      summary_manager: Optional[orbit.utils.SummaryManager] = None,
      eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
      enable_async_checkpointing: bool = False,
  ):
    """Constructor.

    Args:
      distribution_strategy: A distribution strategy.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval', 'train_and_post_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
      run_post_eval: Whether to run a post-training evaluation once after
        training; if True, the eval metrics logs are returned.
      save_summary: Whether to save train and validation summary.
      train_actions: Optional list of Orbit train actions.
      eval_actions: Optional list of Orbit eval actions.
      trainer: The base_trainer.Trainer instance. It should be created within
        the strategy.scope().
      controller_cls: The controller class to manage the train and eval
        process. Must be an orbit.Controller subclass.
      summary_manager: Instance of the summary manager to override the default
        summary manager.
      eval_summary_manager: Instance of the eval summary manager to override
        the default eval summary manager.
      enable_async_checkpointing: Optional boolean indicating whether to enable
        async checkpoint saving.
    """
    self.strategy = distribution_strategy or tf.distribute.get_strategy()
    self._params = params
    self._model_dir = model_dir
    self._mode = mode
    self._run_post_eval = run_post_eval

    self._trainer = trainer or self._build_trainer(
        task,
        train='train' in mode,
        evaluate=('eval' in mode) or run_post_eval)
    assert self.trainer is not None
    self._checkpoint_manager = self._maybe_build_checkpoint_manager()
    self._summary_manager = summary_manager
    self._eval_summary_manager = eval_summary_manager
    self._controller = self._build_controller(
        trainer=self.trainer if 'train' in mode else None,
        evaluator=self.trainer,
        save_summary=save_summary,
        train_actions=train_actions,
        eval_actions=eval_actions,
        controller_cls=controller_cls,
        enable_async_checkpointing=enable_async_checkpointing)

  @property
  def params(self) -> config_definitions.ExperimentConfig:
    """The whole experiment parameters object."""
    return self._params

  @property
  def model_dir(self) -> str:
    """Path to the model folder, which stores checkpoints, params, log, etc."""
    return self._model_dir

  @property
  def trainer(self) -> base_trainer.Trainer:
    """The underlying Orbit Trainer object."""
    return self._trainer

  @property
  def checkpoint_manager(self) -> Optional[tf.train.CheckpointManager]:
    """The CheckpointManager that stores the checkpoints in a train job."""
    return self._checkpoint_manager

  @property
  def controller(self) -> orbit.Controller:
    """The Orbit controller object."""
    return self._controller

  def _build_trainer(self, task: base_task.Task, train: bool,
                     evaluate: bool) -> base_trainer.Trainer:
    """Create trainer."""
    with self.strategy.scope():
      trainer = train_utils.create_trainer(
          self.params,
          task,
          train=train,
          evaluate=evaluate,
          checkpoint_exporter=self._build_best_checkpoint_exporter())
    return trainer

  def _build_best_checkpoint_exporter(self):
    return maybe_create_best_ckpt_exporter(self.params, self.model_dir)

  def _maybe_build_checkpoint_manager(
      self) -> Optional[tf.train.CheckpointManager]:
    """Maybe create a CheckpointManager."""
    assert self.trainer is not None
    if self.trainer.checkpoint:
      if self.model_dir is None:
        raise ValueError('model_dir must be specified, but got None')

      if (not self.strategy) or self.strategy.extended.should_checkpoint:
        ckpt_path = self.model_dir
        max_to_keep = self.params.trainer.max_to_keep
      else:
        # In multi-worker training we need every worker to save checkpoints,
        # because variables can trigger synchronization on read, and
        # synchronization needs all workers to participate. To avoid workers
        # overwriting each other, non-chief workers save to a temporary
        # directory.
        ckpt_path = tempfile.mkdtemp()
        max_to_keep = 1

      checkpoint_manager = tf.train.CheckpointManager(
          self.trainer.checkpoint,
          directory=ckpt_path,
          max_to_keep=max_to_keep,
          step_counter=self.trainer.global_step,
          checkpoint_interval=self.params.trainer.checkpoint_interval,
          init_fn=self.trainer.initialize)
    else:
      checkpoint_manager = None
    return checkpoint_manager

  def _build_controller(
      self,
      trainer,
      evaluator,
      save_summary: bool = True,
      train_actions: Optional[List[orbit.Action]] = None,
      eval_actions: Optional[List[orbit.Action]] = None,
      controller_cls=orbit.Controller,
      enable_async_checkpointing: bool = False,
  ) -> orbit.Controller:
    """Builds a Orbit controler."""
    train_actions = [] if not train_actions else train_actions
    if trainer:
      checkpoint_manager = self.checkpoint_manager
      assert checkpoint_manager, 'Checkpoint manager required but undefined.'
      train_actions += actions.get_train_actions(
          self.params,
          trainer,
          self.model_dir,
          checkpoint_manager=checkpoint_manager,
      )

    eval_actions = [] if not eval_actions else eval_actions
    if evaluator:
      eval_actions += actions.get_eval_actions(self.params, evaluator,
                                               self.model_dir)

    if save_summary:
      eval_summary_dir = os.path.join(
          self.model_dir, self.params.trainer.validation_summary_subdir
      )
    else:
      eval_summary_dir = None

    controller = controller_cls(
        strategy=self.strategy,
        trainer=trainer,
        evaluator=evaluator,
        global_step=self.trainer.global_step,
        steps_per_loop=self.params.trainer.steps_per_loop,
        checkpoint_manager=self.checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpointing,
        summary_dir=os.path.join(self.model_dir, 'train')
        if (save_summary)
        else None,
        eval_summary_dir=eval_summary_dir,
        summary_interval=self.params.trainer.summary_interval
        if (save_summary)
        else None,
        train_actions=train_actions,
        eval_actions=eval_actions,
        summary_manager=self._summary_manager
        if hasattr(self, '_summary_manager')
        else None,
        eval_summary_manager=self._eval_summary_manager
        if hasattr(self, '_eval_summary_manager')
        else None,
    )
    return controller

  def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Run experiments by mode.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: the eval metrics logs when run_post_eval is set to True;
          otherwise, an empty dict ({}).
    """
    mode = self._mode
    params = self.params
    logging.info('Starting to execute mode: %s', mode)
    with self.strategy.scope():
      if mode == 'train' or mode == 'train_and_post_eval':
        self.controller.train(steps=params.trainer.train_steps)
      elif mode == 'train_and_eval':
        self.controller.train_and_evaluate(
            train_steps=params.trainer.train_steps,
            eval_steps=params.trainer.validation_steps,
            eval_interval=params.trainer.validation_interval)
      elif mode == 'eval':
        self.controller.evaluate(steps=params.trainer.validation_steps)
      elif mode == 'continuous_eval':

        def timeout_fn():
          if self.trainer.global_step.numpy() >= params.trainer.train_steps:
            return True
          return False

        self.controller.evaluate_continuously(
            steps=params.trainer.validation_steps,
            timeout=params.trainer.continuous_eval_timeout,
            timeout_fn=timeout_fn)
      else:
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    num_params = train_utils.try_count_params(self.trainer.model)
    if num_params is not None:
      logging.info('Number of trainable params in model: %f Millions.',
                   num_params / 10.**6)

    flops = train_utils.try_count_flops(self.trainer.model)
    if flops is not None:
      logging.info('FLOPs (multi-adds) in model: %f Billions.',
                   flops / 10.**9 / 2)

    if self._run_post_eval or mode == 'train_and_post_eval':
      with self.strategy.scope():
        return self.trainer.model, self.controller.evaluate(
            steps=params.trainer.validation_steps)
    else:
      return self.trainer.model, {}


def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[base_trainer.Trainer] = None,
    controller_cls=orbit.Controller,
    summary_manager: Optional[orbit.utils.SummaryManager] = None,
    eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
    enable_async_checkpointing: bool = False,
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval',
      'train_and_eval', 'train_and_post_eval' or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run a post-training evaluation once after
      training; if True, the eval metrics logs are returned.
    save_summary: Whether to save train and validation summary.
    train_actions: Optional list of Orbit train actions.
    eval_actions: Optional list of Orbit eval actions.
    trainer: The base_trainer.Trainer instance. It should be created within the
      strategy.scope().
    controller_cls: The controller class to manage the train and eval process.
      Must be an orbit.Controller subclass.
    summary_manager: Instance of the summary manager to override the default
      summary manager.
    eval_summary_manager: Instance of the eval summary manager to override the
      default eval summary manager.
    enable_async_checkpointing: Optional boolean indicating whether to enable
      async checkpoint saving.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: `tf.keras.Model` instance.
      eval_logs: the eval metrics logs when run_post_eval is set to True;
        otherwise, an empty dict ({}).
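
  Example:

    A minimal sketch of a call site; `my_task` and `my_params` are
    illustrative placeholders for a Task and ExperimentConfig built elsewhere:

    ```python
    model, eval_logs = run_experiment(
        distribution_strategy=tf.distribute.MirroredStrategy(),
        task=my_task,
        mode='train_and_eval',
        params=my_params,
        model_dir='/tmp/my_experiment')
    ```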
  """
  runner = OrbitExperimentRunner(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=mode,
      params=params,
      model_dir=model_dir,
      run_post_eval=run_post_eval,
      save_summary=save_summary,
      train_actions=train_actions,
      eval_actions=eval_actions,
      trainer=trainer,
      controller_cls=controller_cls,
      summary_manager=summary_manager,
      eval_summary_manager=eval_summary_manager,
      enable_async_checkpointing=enable_async_checkpointing,
  )
  return runner.run()