Commit 6e2a1d5e authored by A. Unique TensorFlower

Enable async checkpointing by default in the TensorFlow Model Garden.

PiperOrigin-RevId: 532860554
Parent 8f17df9a
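For context (this sketch is illustrative, not part of the commit): async checkpointing lets a save call return as soon as variable values are snapshotted off the accelerator, while serialization and the file write continue on a background thread instead of blocking the training loop. A minimal sketch of the mechanism at the plain-TensorFlow level, assuming a TF release whose tf.train.CheckpointOptions exposes experimental_enable_async_checkpoint:

import tensorflow as tf

# Toy state to checkpoint; any tracked variables work the same way.
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory='/tmp/async_ckpt_demo', max_to_keep=3)

# With async checkpointing enabled, save() returns once the variable values
# have been copied; the write itself happens on a background thread.
options = tf.train.CheckpointOptions(experimental_enable_async_checkpoint=True)
step.assign_add(100)
manager.save(options=options)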
@@ -71,6 +71,7 @@ class OrbitExperimentRunner:
       controller_cls=orbit.Controller,
       summary_manager: Optional[orbit.utils.SummaryManager] = None,
       eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
+      enable_async_checkpointing: bool = False,
   ):
     """Constructor.
@@ -94,6 +95,8 @@ class OrbitExperimentRunner:
         summary manager.
       eval_summary_manager: Instance of the eval summary manager to override
         default eval summary manager.
+      enable_async_checkpointing: Optional boolean indicating whether to enable
+        async checkpoint saving.
     """
     self.strategy = distribution_strategy or tf.distribute.get_strategy()
     self._params = params
@@ -115,7 +118,8 @@ class OrbitExperimentRunner:
         save_summary=save_summary,
         train_actions=train_actions,
         eval_actions=eval_actions,
-        controller_cls=controller_cls)
+        controller_cls=controller_cls,
+        enable_async_checkpointing=enable_async_checkpointing)

   @property
   def params(self) -> config_definitions.ExperimentConfig:
@@ -188,13 +192,16 @@ class OrbitExperimentRunner:
       checkpoint_manager = None
     return checkpoint_manager

-  def _build_controller(self,
-                        trainer,
-                        evaluator,
-                        save_summary: bool = True,
-                        train_actions: Optional[List[orbit.Action]] = None,
-                        eval_actions: Optional[List[orbit.Action]] = None,
-                        controller_cls=orbit.Controller) -> orbit.Controller:
+  def _build_controller(
+      self,
+      trainer,
+      evaluator,
+      save_summary: bool = True,
+      train_actions: Optional[List[orbit.Action]] = None,
+      eval_actions: Optional[List[orbit.Action]] = None,
+      controller_cls=orbit.Controller,
+      enable_async_checkpointing: bool = False,
+  ) -> orbit.Controller:
     """Builds an Orbit controller."""
     train_actions = [] if not train_actions else train_actions
     if trainer:
@@ -223,6 +230,7 @@ class OrbitExperimentRunner:
         global_step=self.trainer.global_step,
         steps_per_loop=self.params.trainer.steps_per_loop,
         checkpoint_manager=self.checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpointing,
         summary_dir=os.path.join(self.model_dir, 'train')
         if (save_summary)
         else None,
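The hunk above forwards the new argument into the controller construction; controller_cls defaults to orbit.Controller, whose enable_async_checkpointing parameter does the actual work. A standalone sketch of that wiring with a hypothetical no-op trainer (everything here except the enable_async_checkpointing argument is illustrative, and assumes an Orbit version that ships with this change):

import orbit
import tensorflow as tf

class NoOpTrainer(orbit.AbstractTrainer):
  """Minimal trainer stub standing in for the runner's real trainer."""

  def __init__(self):
    self.global_step = tf.Variable(0, dtype=tf.int64, trainable=False)

  def train(self, num_steps):
    # Orbit passes num_steps as a tensor; advance the step counter only.
    self.global_step.assign_add(tf.cast(num_steps, tf.int64))
    return {'loss': 0.0}

trainer = NoOpTrainer()
checkpoint = tf.train.Checkpoint(global_step=trainer.global_step)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory='/tmp/demo_model_dir', max_to_keep=5)

controller = orbit.Controller(
    trainer=trainer,
    global_step=trainer.global_step,
    steps_per_loop=10,
    checkpoint_manager=checkpoint_manager,
    enable_async_checkpointing=True,  # the argument added by this commit
)
controller.train(steps=10)  # checkpoint at completion is saved asynchronously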
@@ -309,6 +317,7 @@ def run_experiment(
     controller_cls=orbit.Controller,
     summary_manager: Optional[orbit.utils.SummaryManager] = None,
     eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
+    enable_async_checkpointing: bool = False,
 ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
   """Runs train/eval configured by the experiment params.
@@ -332,6 +341,8 @@ def run_experiment(
       manager.
     eval_summary_manager: Instance of the eval summary manager to override
       default eval summary manager.
+    enable_async_checkpointing: Optional boolean indicating whether to enable
+      async checkpoint saving.

   Returns:
     A 2-tuple of (model, eval_logs).
@@ -353,5 +364,6 @@ def run_experiment(
       controller_cls=controller_cls,
       summary_manager=summary_manager,
       eval_summary_manager=eval_summary_manager,
+      enable_async_checkpointing=enable_async_checkpointing,
   )
   return runner.run()
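Taken together, run_experiment now exposes the switch directly to callers. A hedged usage sketch, assuming the standard Model Garden config and task factories; the experiment name and model_dir here are placeholders, not part of the commit:

import tensorflow as tf
from official.common import registry_imports  # registers built-in experiments
from official.core import exp_factory
from official.core import task_factory
from official.core import train_lib

params = exp_factory.get_exp_config('resnet_imagenet')  # any registered name
task = task_factory.get_task(params.task, logging_dir='/tmp/model_dir')

model, eval_logs = train_lib.run_experiment(
    distribution_strategy=tf.distribute.get_strategy(),
    task=task,
    mode='train_and_eval',
    params=params,
    model_dir='/tmp/model_dir',
    enable_async_checkpointing=True,  # the new switch from this commit
)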
@@ -38,6 +38,11 @@ flags.DEFINE_integer(
     default=None,
     help='The number of total training steps for the pretraining job.')

+flags.DEFINE_bool(
+    'enable_async_checkpointing',
+    default=True,
+    help='A boolean indicating whether to enable async checkpoint saving')
+

 def _run_experiment_with_preemption_recovery(params, model_dir):
   """Runs experiment and tries to reconnect when encountering a preemption."""
@@ -62,7 +67,8 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
           task=task,
           mode=FLAGS.mode,
           params=params,
-          model_dir=model_dir)
+          model_dir=model_dir,
+          enable_async_checkpointing=FLAGS.enable_async_checkpointing)
       keep_training = False
     except tf.errors.OpError as e:
...
@@ -32,6 +32,11 @@ from official.vision.utils import summary_manager

 FLAGS = flags.FLAGS

+flags.DEFINE_bool(
+    'enable_async_checkpointing',
+    default=True,
+    help='A boolean indicating whether to enable async checkpoint saving')
+

 def _run_experiment_with_preemption_recovery(params, model_dir):
   """Runs experiment and tries to reconnect when encountering a preemption."""
@@ -60,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
           eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
               params=params, model_dir=model_dir
           ),
+          enable_async_checkpointing=FLAGS.enable_async_checkpointing,
       )
       keep_training = False
...
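The two DEFINE_bool blocks above default the flag to True in the trainer binaries, which is what makes async checkpointing on by default. Because absl automatically generates a negated --noenable_async_checkpointing form, jobs can still opt out from the command line. A minimal standalone illustration of that flag behavior (not the actual trainer script):

from absl import app
from absl import flags

flags.DEFINE_bool(
    'enable_async_checkpointing',
    default=True,
    help='A boolean indicating whether to enable async checkpoint saving')

FLAGS = flags.FLAGS


def main(_):
  # Run with no arguments and this prints True (opt-out, not opt-in);
  # run with --noenable_async_checkpointing and it prints False.
  print('async checkpointing enabled:', FLAGS.enable_async_checkpointing)


if __name__ == '__main__':
  app.run(main)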