提交 2b4fe39d 编写于 作者: A A. Unique TensorFlower

Support async checkpoint in Orbit trainer/controller.

This CL adds a field in Orbit trainer/controller indicating whether async checkpoint is enabled for checkpoint saving. By default, this value is set to False, which is equivalent to the existing behavior.

In addition, a sync barrier is added at the end of training (in the controller) to make sure user code won't prematurely access the checkpoint file/state while async checkpoint saving is still ongoing.

PiperOrigin-RevId: 529300903
上级 6b2ed0df
...@@ -96,6 +96,7 @@ class Controller: ...@@ -96,6 +96,7 @@ class Controller:
# Train related # Train related
steps_per_loop: Optional[Union[int, Callable[[int], int]]] = None, steps_per_loop: Optional[Union[int, Callable[[int], int]]] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None, checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
enable_async_checkpointing: bool = False,
# Summary related # Summary related
summary_interval: Optional[int] = None, summary_interval: Optional[int] = None,
summary_dir: Optional[str] = None, summary_dir: Optional[str] = None,
...@@ -141,6 +142,8 @@ class Controller: ...@@ -141,6 +142,8 @@ class Controller:
the model will be restored from the most recent checkpoint inside this the model will be restored from the most recent checkpoint inside this
`__init__` method. If not provided, the `Controller` will not `__init__` method. If not provided, the `Controller` will not
automatically save to or restore from checkpoints. automatically save to or restore from checkpoints.
enable_async_checkpointing: Optional bool indicating whether to enable
async checkpoint saving.
summary_interval: Step interval for training summaries. Note that this summary_interval: Step interval for training summaries. Note that this
argument only applies to `tf.summary` calls inside the `trainer.train` argument only applies to `tf.summary` calls inside the `trainer.train`
function. Summaries written by the `Controller` (specifically function. Summaries written by the `Controller` (specifically
...@@ -204,6 +207,10 @@ class Controller: ...@@ -204,6 +207,10 @@ class Controller:
self.global_step = global_step self.global_step = global_step
self.checkpoint_manager = checkpoint_manager self.checkpoint_manager = checkpoint_manager
self._enable_async_checkpoint_saving = enable_async_checkpointing
self._checkpoint_options = tf.train.CheckpointOptions(
enable_async=enable_async_checkpointing
)
if self.trainer is not None: if self.trainer is not None:
self.step_timer = None self.step_timer = None
...@@ -244,6 +251,10 @@ class Controller: ...@@ -244,6 +251,10 @@ class Controller:
`CheckpointManager` was passed to `Controller.__init__`) and summarize `CheckpointManager` was passed to `Controller.__init__`) and summarize
training output (if `summary_dir` is set). training output (if `summary_dir` is set).
When async checkpointing is enabled, a sync is triggered at the end of this
method to make sure any ongoing async checkpoint saving is finished before
returning.
Args: Args:
steps: The global step count to train up to. steps: The global step count to train up to.
checkpoint_at_completion: Whether to save a checkpoint when this method checkpoint_at_completion: Whether to save a checkpoint when this method
...@@ -264,6 +275,8 @@ class Controller: ...@@ -264,6 +275,8 @@ class Controller:
if checkpoint_at_completion: if checkpoint_at_completion:
self._maybe_save_checkpoint(check_interval=False) self._maybe_save_checkpoint(check_interval=False)
self._sync_on_async_checkpointing()
def evaluate(self, steps: int = -1) -> Optional[runner.Output]: def evaluate(self, steps: int = -1) -> Optional[runner.Output]:
"""Runs evaluation for the given number of steps. """Runs evaluation for the given number of steps.
...@@ -339,6 +352,10 @@ class Controller: ...@@ -339,6 +352,10 @@ class Controller:
In addition, this method will run a final evaluation at the end of the In addition, this method will run a final evaluation at the end of the
training sequence. training sequence.
When async checkpointing is enabled, a sync is triggered at the end of this
method to make sure any ongoing async checkpoint saving is finished before
returning.
Args: Args:
train_steps: The global step count to train up to. train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If -1, this eval_steps: The number of steps to run during an evaluation. If -1, this
...@@ -365,6 +382,7 @@ class Controller: ...@@ -365,6 +382,7 @@ class Controller:
output = self.evaluate(steps=eval_steps) output = self.evaluate(steps=eval_steps)
current_step = self.global_step.numpy() current_step = self.global_step.numpy()
self._maybe_save_checkpoint(check_interval=False) self._maybe_save_checkpoint(check_interval=False)
self._sync_on_async_checkpointing()
return output return output
def evaluate_continuously( def evaluate_continuously(
...@@ -539,6 +557,13 @@ class Controller: ...@@ -539,6 +557,13 @@ class Controller:
f"`{attribute}` is not set. Pass `{attribute}` to " f"`{attribute}` is not set. Pass `{attribute}` to "
f"`Controller.__init__` before calling `{for_method}()`.") f"`Controller.__init__` before calling `{for_method}()`.")
def _sync_on_async_checkpointing(self):
"""Force to wait for the async checkpoint saving (if any) to finish."""
# pylint: disable=protected-access
if self.checkpoint_manager:
logging.info("Sync on async checkpoint saving.")
self.checkpoint_manager.sync()
class StepTimer: class StepTimer:
"""Utility class for measuring steps/second.""" """Utility class for measuring steps/second."""
......
...@@ -294,7 +294,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -294,7 +294,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
train_steps=10, eval_steps=2, eval_interval=6) train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10) self.assertEqual(test_runner.global_step, 10)
def test_has_checkpoint_no_summaries(self): @parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_has_checkpoint_no_summaries(self, enable_async_checkpoint_saving):
test_runner = TestRunner() test_runner = TestRunner()
# Has checkpoint, but no summary directories. # Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model) checkpoint = tf.train.Checkpoint(model=test_runner.model)
...@@ -308,6 +312,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -308,6 +312,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
evaluator=test_runner, evaluator=test_runner,
global_step=test_runner.global_step, global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
steps_per_loop=2) steps_per_loop=2)
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6) train_steps=10, eval_steps=2, eval_interval=6)
...@@ -317,7 +322,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -317,7 +322,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertEmpty(tf.io.gfile.glob( self.assertEmpty(tf.io.gfile.glob(
os.path.join(checkpoint_manager.directory, "events.*"))) os.path.join(checkpoint_manager.directory, "events.*")))
def test_has_checkpoint_eval_summary_only(self): @parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_has_checkpoint_eval_summary_only(
self, enable_async_checkpoint_saving
):
test_runner = TestRunner() test_runner = TestRunner()
# Has checkpoint, but no summary directories. # Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model) checkpoint = tf.train.Checkpoint(model=test_runner.model)
...@@ -331,6 +342,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -331,6 +342,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
evaluator=test_runner, evaluator=test_runner,
global_step=test_runner.global_step, global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"), eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
steps_per_loop=2) steps_per_loop=2)
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
...@@ -344,7 +356,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -344,7 +356,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotEmpty(tf.io.gfile.glob( self.assertNotEmpty(tf.io.gfile.glob(
os.path.join(self.model_dir, "summaries/eval/events.*"))) os.path.join(self.model_dir, "summaries/eval/events.*")))
def test_restore_from_most_recent_checkpoint(self): @parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_restore_from_most_recent_checkpoint(
self, enable_async_checkpoint_saving
):
test_runner = TestRunner() test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model) checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager( checkpoint_manager = tf.train.CheckpointManager(
...@@ -357,6 +375,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -357,6 +375,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
trainer=test_runner, trainer=test_runner,
global_step=test_runner.global_step, global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"), eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
steps_per_loop=5) steps_per_loop=5)
test_controller.train(20) test_controller.train(20)
...@@ -364,9 +383,15 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -364,9 +383,15 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
restored_path = test_controller.restore_checkpoint() restored_path = test_controller.restore_checkpoint()
self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1]) self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])
@parameterized.named_parameters(("return_numpy", True), @parameterized.named_parameters(
("return_tensor", False)) ("return_numpy_sync_checkpoint_saving", True, False),
def test_train_and_evaluate(self, return_numpy): ("return_numpy_async_checkpoint_saving", True, True),
("return_tensor_sync_checkpoint_saving", False, False),
("return_tensor_async_checkpoint_saving", False, True),
)
def test_train_and_evaluate(
self, return_numpy, enable_async_checkpoint_saving
):
test_runner = TestRunner(return_numpy=return_numpy) test_runner = TestRunner(return_numpy=return_numpy)
checkpoint = tf.train.Checkpoint( checkpoint = tf.train.Checkpoint(
...@@ -384,6 +409,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -384,6 +409,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
steps_per_loop=2, steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval")) eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6) train_steps=10, eval_steps=2, eval_interval=6)
...@@ -403,7 +429,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -403,7 +429,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
summaries_with_matching_keyword( summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_only(self): @parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_train_only(self, enable_async_checkpoint_saving):
test_runner = TestRunner() test_runner = TestRunner()
checkpoint = tf.train.Checkpoint( checkpoint = tf.train.Checkpoint(
...@@ -420,6 +450,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -420,6 +450,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
steps_per_loop=2, steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"), eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
) )
test_controller.train(steps=10) test_controller.train(steps=10)
...@@ -497,7 +528,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -497,7 +528,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
checkpoint_manager=checkpoint_manager) checkpoint_manager=checkpoint_manager)
test_controller.evaluate() test_controller.evaluate()
def test_already_trained_model(self): @parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_already_trained_model(self, enable_async_checkpoint_saving):
test_runner = TestRunner() test_runner = TestRunner()
test_runner.global_step.assign(10) test_runner.global_step.assign(10)
...@@ -513,7 +548,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -513,7 +548,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
trainer=test_runner, trainer=test_runner,
global_step=test_runner.global_step, global_step=test_runner.global_step,
steps_per_loop=2, steps_per_loop=2,
checkpoint_manager=checkpoint_manager) checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving)
# `global_step` is already `train_steps`. # `global_step` is already `train_steps`.
test_controller.train(steps=10) test_controller.train(steps=10)
...@@ -533,7 +569,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -533,7 +569,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
steps_per_loop=2, steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"), summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2, summary_interval=2,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager
) )
test_controller.train(steps=10) test_controller.train(steps=10)
...@@ -594,6 +630,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -594,6 +630,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
interval = min(train_steps - self.global_step.numpy(), eval_interval) interval = min(train_steps - self.global_step.numpy(), eval_interval)
num_steps = self.global_step.numpy() + interval num_steps = self.global_step.numpy() + interval
self.train(steps=num_steps, checkpoint_at_completion=False) self.train(steps=num_steps, checkpoint_at_completion=False)
self._sync_on_async_checkpointing()
self.evaluate(steps=eval_steps) self.evaluate(steps=eval_steps)
# Early stop condition. # Early stop condition.
if test_runner.eval_loss.result() < 0.1: if test_runner.eval_loss.result() < 0.1:
...@@ -672,7 +709,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -672,7 +709,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6) train_steps=10, eval_steps=2, eval_interval=6)
def test_eval_and_checkpoint_interval(self): @parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_eval_and_checkpoint_interval(self, enable_async_checkpoint_saving):
test_runner = TestRunner() test_runner = TestRunner()
checkpoint = tf.train.Checkpoint( checkpoint = tf.train.Checkpoint(
...@@ -689,6 +730,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -689,6 +730,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
global_step=test_runner.global_step, global_step=test_runner.global_step,
steps_per_loop=10, steps_per_loop=10,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
summary_dir=self.model_dir) summary_dir=self.model_dir)
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=5) train_steps=10, eval_steps=2, eval_interval=5)
...@@ -803,7 +845,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -803,7 +845,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
trainer=test_runner, trainer=test_runner,
global_step=test_runner.global_step, global_step=test_runner.global_step,
steps_per_loop=steps_per_loop_fn, steps_per_loop=steps_per_loop_fn,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager
) )
test_controller.train(steps=10) test_controller.train(steps=10)
self.assertEqual(test_runner.global_step, 10) self.assertEqual(test_runner.global_step, 10)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册