# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Optional, Tuple

# Import libraries

from absl import logging
import orbit
import tensorflow as tf

from official.core import actions
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import train_utils

maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter


def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.
    trainer: the base_trainer.Trainer instance. It should be created within the
      strategy.scope().

  Returns:
    A 2-tuple of (model, eval_logs).
      model: `tf.keras.Model` instance.
      eval_logs: returns eval metrics logs when run_post_eval is set to True,
        otherwise, returns {}.

  Raises:
    NotImplementedError: If `mode` is not one of the supported modes.
  """

  # Trainer creation (variables, optimizer slots) must happen under the
  # distribution strategy scope so state is placed on the right devices.
  with distribution_strategy.scope():
    if not trainer:
      trainer = train_utils.create_trainer(
          params,
          task,
          train='train' in mode,
          evaluate=('eval' in mode) or run_post_eval,
          checkpoint_exporter=maybe_create_best_ckpt_exporter(
              params, model_dir))

  if trainer.checkpoint:
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)
    # Adds recovery handling.
    trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
  else:
    checkpoint_manager = None

  controller = orbit.Controller(
      strategy=distribution_strategy,
      # Only attach the trainer when the mode actually trains; the evaluator
      # is always attached since eval may be requested via run_post_eval.
      trainer=trainer if 'train' in mode else None,
      evaluator=trainer,
      global_step=trainer.global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=os.path.join(
          model_dir,
          params.trainer.validation_summary_subdir) if save_summary else None,
      summary_interval=params.trainer.summary_interval
      if save_summary else None,
      eval_actions=actions.get_eval_actions(params, trainer, model_dir))

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # Stop waiting for new checkpoints once training reached its final
        # step; otherwise continuous eval would block until the timeout.
        if trainer.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  num_params = train_utils.try_count_params(trainer.model)
  if num_params is not None:
    logging.info('Number of trainable params in model: %f Millions.',
                 num_params / 10.**6)

  if run_post_eval:
    with distribution_strategy.scope():
      return trainer.model, trainer.evaluate(
          tf.convert_to_tensor(params.trainer.validation_steps))
  else:
    return trainer.model, {}