diff --git a/orbit/controller.py b/orbit/controller.py
index 1f277231fc8a549693e51347ece815a2c363e6f1..9edaf92b42cbac05b77dced553552b54ce86420e 100644
--- a/orbit/controller.py
+++ b/orbit/controller.py
@@ -96,6 +96,7 @@ class Controller:
       # Train related
       steps_per_loop: Optional[Union[int, Callable[[int], int]]] = None,
       checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
+      enable_async_checkpointing: bool = False,
       # Summary related
       summary_interval: Optional[int] = None,
       summary_dir: Optional[str] = None,
@@ -141,6 +142,8 @@ class Controller:
         the model will be restored from the most recent checkpoint inside this
         `__init__` method. If not provided, the `Controller` will not
         automatically save to or restore from checkpoints.
+      enable_async_checkpointing: A boolean indicating whether to enable
+        asynchronous checkpoint saving. Defaults to `False`.
       summary_interval: Step interval for training summaries. Note that this
         argument only applies to `tf.summary` calls inside the `trainer.train`
         function. Summaries written by the `Controller` (specifically
@@ -204,6 +207,10 @@ class Controller:
 
     self.global_step = global_step
     self.checkpoint_manager = checkpoint_manager
+    self._enable_async_checkpoint_saving = enable_async_checkpointing
+    self._checkpoint_options = tf.train.CheckpointOptions(
+        enable_async=enable_async_checkpointing
+    )
 
     if self.trainer is not None:
       self.step_timer = None
@@ -244,6 +251,10 @@ class Controller:
     `CheckpointManager` was passed to `Controller.__init__`) and summarize
     training output (if `summary_dir` is set).
 
+    When async checkpointing is enabled, a sync is triggered at the end of
+    this method to ensure that any ongoing async checkpoint saving has
+    finished before returning.
+
     Args:
       steps: The global step count to train up to.
       checkpoint_at_completion: Whether to save a checkpoint when this method
@@ -264,6 +275,8 @@ class Controller:
 
     if checkpoint_at_completion:
       self._maybe_save_checkpoint(check_interval=False)
 
+    self._sync_on_async_checkpointing()
+
   def evaluate(self, steps: int = -1) -> Optional[runner.Output]:
     """Runs evaluation for the given number of steps.
@@ -339,6 +352,10 @@ class Controller:
     In addition, this method will run a final evaluation at the end of the
     training sequence.
 
+    When async checkpointing is enabled, a sync is triggered at the end of
+    this method to ensure that any ongoing async checkpoint saving has
+    finished before returning.
+
     Args:
       train_steps: The global step count to train up to.
       eval_steps: The number of steps to run during an evaluation. If -1, this
@@ -365,6 +382,7 @@ class Controller:
       output = self.evaluate(steps=eval_steps)
       current_step = self.global_step.numpy()
     self._maybe_save_checkpoint(check_interval=False)
+    self._sync_on_async_checkpointing()
     return output
 
   def evaluate_continuously(
@@ -539,6 +557,13 @@ class Controller:
           f"`{attribute}` is not set. Pass `{attribute}` to "
           f"`Controller.__init__` before calling `{for_method}()`.")
 
+  def _sync_on_async_checkpointing(self):
+    """Waits for any ongoing async checkpoint saving to finish."""
+    # pylint: disable=protected-access
+    if self.checkpoint_manager:
+      logging.info("Sync on async checkpoint saving.")
+      self.checkpoint_manager.sync()
+
 
 class StepTimer:
   """Utility class for measuring steps/second."""
diff --git a/orbit/controller_test.py b/orbit/controller_test.py
index 9077f312bf2dbf65b4b23dd64d14563e7f8a6e5d..8ffad830bd17fff07b32ff895b524c39a2620e7a 100644
--- a/orbit/controller_test.py
+++ b/orbit/controller_test.py
@@ -294,7 +294,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         train_steps=10, eval_steps=2, eval_interval=6)
     self.assertEqual(test_runner.global_step, 10)
 
-  def test_has_checkpoint_no_summaries(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_has_checkpoint_no_summaries(self, enable_async_checkpoint_saving):
     test_runner = TestRunner()
     # Has checkpoint, but no summary directories.
     checkpoint = tf.train.Checkpoint(model=test_runner.model)
@@ -308,6 +312,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         evaluator=test_runner,
         global_step=test_runner.global_step,
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         steps_per_loop=2)
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=6)
@@ -317,7 +322,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEmpty(tf.io.gfile.glob(
         os.path.join(checkpoint_manager.directory, "events.*")))
 
-  def test_has_checkpoint_eval_summary_only(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_has_checkpoint_eval_summary_only(
+      self, enable_async_checkpoint_saving
+  ):
     test_runner = TestRunner()
     # Has checkpoint, but no summary directories.
     checkpoint = tf.train.Checkpoint(model=test_runner.model)
@@ -331,6 +342,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         evaluator=test_runner,
         global_step=test_runner.global_step,
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
         steps_per_loop=2)
     test_controller.train_and_evaluate(
@@ -344,7 +356,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotEmpty(tf.io.gfile.glob(
         os.path.join(self.model_dir, "summaries/eval/events.*")))
 
-  def test_restore_from_most_recent_checkpoint(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_restore_from_most_recent_checkpoint(
+      self, enable_async_checkpoint_saving
+  ):
     test_runner = TestRunner()
     checkpoint = tf.train.Checkpoint(model=test_runner.model)
     checkpoint_manager = tf.train.CheckpointManager(
@@ -357,6 +375,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         trainer=test_runner,
         global_step=test_runner.global_step,
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
         steps_per_loop=5)
     test_controller.train(20)
@@ -364,9 +383,15 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
     restored_path = test_controller.restore_checkpoint()
     self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])
 
-  @parameterized.named_parameters(("return_numpy", True),
-                                  ("return_tensor", False))
-  def test_train_and_evaluate(self, return_numpy):
+  @parameterized.named_parameters(
+      ("return_numpy_sync_checkpoint_saving", True, False),
+      ("return_numpy_async_checkpoint_saving", True, True),
+      ("return_tensor_sync_checkpoint_saving", False, False),
+      ("return_tensor_async_checkpoint_saving", False, True),
+  )
+  def test_train_and_evaluate(
+      self, return_numpy, enable_async_checkpoint_saving
+  ):
     test_runner = TestRunner(return_numpy=return_numpy)
 
     checkpoint = tf.train.Checkpoint(
@@ -384,6 +409,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         steps_per_loop=2,
         summary_dir=os.path.join(self.model_dir, "summaries/train"),
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=6)
@@ -403,7 +429,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         summaries_with_matching_keyword(
             "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
 
-  def test_train_only(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_train_only(self, enable_async_checkpoint_saving):
     test_runner = TestRunner()
 
     checkpoint = tf.train.Checkpoint(
@@ -420,6 +450,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         steps_per_loop=2,
         summary_dir=os.path.join(self.model_dir, "summaries/train"),
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
     )
     test_controller.train(steps=10)
@@ -497,7 +528,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         checkpoint_manager=checkpoint_manager)
     test_controller.evaluate()
 
-  def test_already_trained_model(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_already_trained_model(self, enable_async_checkpoint_saving):
     test_runner = TestRunner()
     test_runner.global_step.assign(10)
 
@@ -513,7 +548,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         trainer=test_runner,
         global_step=test_runner.global_step,
         steps_per_loop=2,
-        checkpoint_manager=checkpoint_manager)
+        checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving)
 
     # `global_step` is already `train_steps`.
     test_controller.train(steps=10)
@@ -533,7 +569,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         steps_per_loop=2,
         summary_dir=os.path.join(self.model_dir, "summaries/train"),
         summary_interval=2,
-        checkpoint_manager=checkpoint_manager,
+        checkpoint_manager=checkpoint_manager
     )
     test_controller.train(steps=10)
 
@@ -594,6 +630,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
           interval = min(train_steps - self.global_step.numpy(), eval_interval)
          num_steps = self.global_step.numpy() + interval
           self.train(steps=num_steps, checkpoint_at_completion=False)
+          self._sync_on_async_checkpointing()
           self.evaluate(steps=eval_steps)
           # Early stop condition.
           if test_runner.eval_loss.result() < 0.1:
@@ -672,7 +709,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=6)
 
-  def test_eval_and_checkpoint_interval(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_eval_and_checkpoint_interval(self, enable_async_checkpoint_saving):
     test_runner = TestRunner()
 
     checkpoint = tf.train.Checkpoint(
@@ -689,6 +730,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         global_step=test_runner.global_step,
         steps_per_loop=10,
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         summary_dir=self.model_dir)
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=5)
@@ -803,7 +845,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         trainer=test_runner,
         global_step=test_runner.global_step,
         steps_per_loop=steps_per_loop_fn,
-        checkpoint_manager=checkpoint_manager,
+        checkpoint_manager=checkpoint_manager
     )
     test_controller.train(steps=10)
     self.assertEqual(test_runner.global_step, 10)
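
Usage sketch (not part of the diff): a minimal example of how the new `enable_async_checkpointing` flag could be wired up from client code. `ToyTrainer`, the step counts, and the checkpoint directory are placeholders invented for this sketch, and it assumes a TensorFlow version whose `tf.train.CheckpointOptions` supports `enable_async` and whose `tf.train.CheckpointManager` exposes `sync()`; only `orbit.Controller`, its new argument, and the `tf.train.Checkpoint`/`tf.train.CheckpointManager` wiring mirror the change above.

import orbit
import tensorflow as tf


class ToyTrainer(orbit.AbstractTrainer):
  """Minimal trainer used only to drive the Controller in this sketch."""

  def __init__(self):
    self.global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
    self.model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

  def train(self, num_steps):
    # A real trainer would run its training loop here; the sketch only
    # advances the global step so the checkpoint interval has something
    # to track.
    self.global_step.assign_add(tf.cast(num_steps, tf.int64))


trainer = ToyTrainer()
checkpoint = tf.train.Checkpoint(model=trainer.model)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory="/tmp/toy_model",  # placeholder path
    max_to_keep=3,
    step_counter=trainer.global_step,
    checkpoint_interval=5,
)

controller = orbit.Controller(
    trainer=trainer,
    global_step=trainer.global_step,
    steps_per_loop=5,
    checkpoint_manager=checkpoint_manager,
    # The new flag: the Controller constructs
    # tf.train.CheckpointOptions(enable_async=True) for its saves, and
    # train()/train_and_evaluate() sync before returning so that no
    # checkpoint save is still in flight.
    enable_async_checkpointing=True,
)

# Per the change above, train() ends with _sync_on_async_checkpointing().
controller.train(steps=10)

Because `Controller.train()` and `Controller.train_and_evaluate()` now sync on completion, callers do not need any extra step before reading the checkpoint directory, even with async saving enabled.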