提交 6e2a1d5e 编写于 作者: A A. Unique TensorFlower

Enable async checkpoint by default in Tensorflow model garden.

PiperOrigin-RevId: 532860554
上级 8f17df9a
......@@ -71,6 +71,7 @@ class OrbitExperimentRunner:
controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
enable_async_checkpointing: bool = False,
):
"""Constructor.
......@@ -94,6 +95,8 @@ class OrbitExperimentRunner:
summary manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
"""
self.strategy = distribution_strategy or tf.distribute.get_strategy()
self._params = params
......@@ -115,7 +118,8 @@ class OrbitExperimentRunner:
save_summary=save_summary,
train_actions=train_actions,
eval_actions=eval_actions,
controller_cls=controller_cls)
controller_cls=controller_cls,
enable_async_checkpointing=enable_async_checkpointing)
@property
def params(self) -> config_definitions.ExperimentConfig:
......@@ -188,13 +192,16 @@ class OrbitExperimentRunner:
checkpoint_manager = None
return checkpoint_manager
def _build_controller(self,
trainer,
evaluator,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
controller_cls=orbit.Controller) -> orbit.Controller:
def _build_controller(
self,
trainer,
evaluator,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
controller_cls=orbit.Controller,
enable_async_checkpointing: bool = False,
) -> orbit.Controller:
"""Builds an Orbit controller."""
train_actions = [] if not train_actions else train_actions
if trainer:
......@@ -223,6 +230,7 @@ class OrbitExperimentRunner:
global_step=self.trainer.global_step,
steps_per_loop=self.params.trainer.steps_per_loop,
checkpoint_manager=self.checkpoint_manager,
enable_async_checkpointing=enable_async_checkpointing,
summary_dir=os.path.join(self.model_dir, 'train')
if (save_summary)
else None,
......@@ -309,6 +317,7 @@ def run_experiment(
controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
enable_async_checkpointing: bool = False,
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
......@@ -332,6 +341,8 @@ def run_experiment(
manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
Returns:
A 2-tuple of (model, eval_logs).
......@@ -353,5 +364,6 @@ def run_experiment(
controller_cls=controller_cls,
summary_manager=summary_manager,
eval_summary_manager=eval_summary_manager,
enable_async_checkpointing=enable_async_checkpointing,
)
return runner.run()
......@@ -38,6 +38,11 @@ flags.DEFINE_integer(
default=None,
help='The number of total training steps for the pretraining job.')
# Command-line flag controlling async checkpoint saving. Note the flag
# defaults to True (async ON for this binary) even though the underlying
# run_experiment() parameter defaults to False — the flag is what flips
# the default on, per this commit's intent.
flags.DEFINE_bool(
'enable_async_checkpointing',
default=True,
help='A boolean indicating whether to enable async checkpoint saving')
def _run_experiment_with_preemption_recovery(params, model_dir):
"""Runs experiment and tries to reconnect when encountering a preemption."""
......@@ -62,7 +67,8 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
model_dir=model_dir,
enable_async_checkpointing=FLAGS.enable_async_checkpointing)
keep_training = False
except tf.errors.OpError as e:
......
......@@ -32,6 +32,11 @@ from official.vision.utils import summary_manager
FLAGS = flags.FLAGS
# Command-line flag controlling async checkpoint saving; defaults to True
# so this trainer enables async checkpointing unless explicitly disabled.
flags.DEFINE_bool(
'enable_async_checkpointing',
default=True,
help='A boolean indicating whether to enable async checkpoint saving')
def _run_experiment_with_preemption_recovery(params, model_dir):
"""Runs experiment and tries to reconnect when encountering a preemption."""
......@@ -60,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
params=params, model_dir=model_dir
),
enable_async_checkpointing=FLAGS.enable_async_checkpointing,
)
keep_training = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册