Commit 2b4fe39d authored by A. Unique TensorFlower

Support async checkpoint in Orbit trainer/controller.

This CL adds a field to the Orbit trainer/controller indicating whether async checkpointing is enabled for checkpoint saving. By default this value is set to False, which matches the existing behavior.

In addition, a sync barrier is added at the end of training (in the controller) to make sure user code won't prematurely access the checkpoint files/state while async checkpoint saving is still in progress.

PiperOrigin-RevId: 529300903
Parent 6b2ed0df
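For illustration (not part of this CL): a minimal sketch of enabling the new flag, using the `Controller` arguments shown in the diff below. `TinyTrainer` and the checkpoint directory are hypothetical stand-ins, modeled on the `TestRunner` used in the tests.

```python
import tensorflow as tf
import orbit


class TinyTrainer(orbit.AbstractTrainer):
  """Hypothetical minimal trainer, standing in for a real implementation."""

  def __init__(self):
    self.model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(2,))])
    self.global_step = tf.Variable(0, dtype=tf.int64, trainable=False)

  def train(self, num_steps):
    # A real trainer would run `num_steps` optimization steps here.
    self.global_step.assign_add(tf.cast(num_steps, self.global_step.dtype))


trainer = TinyTrainer()
checkpoint = tf.train.Checkpoint(model=trainer.model)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory="/tmp/orbit_async_demo",  # hypothetical path
    max_to_keep=3,
    step_counter=trainer.global_step,
    checkpoint_interval=5)

controller = orbit.Controller(
    trainer=trainer,
    global_step=trainer.global_step,
    steps_per_loop=2,
    checkpoint_manager=checkpoint_manager,
    enable_async_checkpointing=True,  # the new flag; defaults to False
)

# With async checkpointing on, train() syncs on any in-flight save before
# returning, so the checkpoint files are safe to read right after this call.
controller.train(steps=10)
```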
@@ -96,6 +96,7 @@ class Controller:
# Train related
steps_per_loop: Optional[Union[int, Callable[[int], int]]] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
enable_async_checkpointing: bool = False,
# Summary related
summary_interval: Optional[int] = None,
summary_dir: Optional[str] = None,
@@ -141,6 +142,8 @@ class Controller:
the model will be restored from the most recent checkpoint inside this
`__init__` method. If not provided, the `Controller` will not
automatically save to or restore from checkpoints.
enable_async_checkpointing: Optional bool indicating whether to enable
async checkpoint saving.
summary_interval: Step interval for training summaries. Note that this
argument only applies to `tf.summary` calls inside the `trainer.train`
function. Summaries written by the `Controller` (specifically
@@ -204,6 +207,10 @@ class Controller:
self.global_step = global_step
self.checkpoint_manager = checkpoint_manager
self._enable_async_checkpoint_saving = enable_async_checkpointing
self._checkpoint_options = tf.train.CheckpointOptions(
enable_async=enable_async_checkpointing
)
if self.trainer is not None:
self.step_timer = None
@@ -244,6 +251,10 @@ class Controller:
`CheckpointManager` was passed to `Controller.__init__`) and summarize
training output (if `summary_dir` is set).
When async checkpointing is enabled, a sync is triggered at the end of this
method to make sure any ongoing async checkpoint saving is finished before
returning.
Args:
steps: The global step count to train up to.
checkpoint_at_completion: Whether to save a checkpoint when this method
@@ -264,6 +275,8 @@ class Controller:
if checkpoint_at_completion:
self._maybe_save_checkpoint(check_interval=False)
self._sync_on_async_checkpointing()
def evaluate(self, steps: int = -1) -> Optional[runner.Output]:
"""Runs evaluation for the given number of steps.
@@ -339,6 +352,10 @@ class Controller:
In addition, this method will run a final evaluation at the end of the
training sequence.
When async checkpointing is enabled, a sync is triggered at the end of this
method to make sure any ongoing async checkpoint saving is finished before
returning.
Args:
train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If -1, this
@@ -365,6 +382,7 @@ class Controller:
output = self.evaluate(steps=eval_steps)
current_step = self.global_step.numpy()
self._maybe_save_checkpoint(check_interval=False)
self._sync_on_async_checkpointing()
return output
def evaluate_continuously(
@@ -539,6 +557,13 @@ class Controller:
f"`{attribute}` is not set. Pass `{attribute}` to "
f"`Controller.__init__` before calling `{for_method}()`.")
def _sync_on_async_checkpointing(self):
"""Force to wait for the async checkpoint saving (if any) to finish."""
# pylint: disable=protected-access
if self.checkpoint_manager:
logging.info("Sync on async checkpoint saving.")
self.checkpoint_manager.sync()
class StepTimer:
"""Utility class for measuring steps/second."""
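For context before the test changes below: a minimal sketch of the underlying TF mechanism this CL builds on, assuming a TF version where `tf.train.CheckpointOptions(enable_async=...)` and `CheckpointManager.sync()` are available (both appear in the diff above).

```python
import tensorflow as tf

checkpoint = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64))
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/async_ckpt_demo", max_to_keep=3)

# With enable_async=True, save() returns once an in-memory snapshot of the
# variables is taken; writing to disk continues in the background.
options = tf.train.CheckpointOptions(enable_async=True)
manager.save(checkpoint_number=0, options=options)

# The barrier this CL adds at the end of training: block until any pending
# async save finishes, so the files on disk are complete and readable.
manager.sync()
print(manager.latest_checkpoint)
```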
@@ -294,7 +294,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
def test_has_checkpoint_no_summaries(self):
@parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_has_checkpoint_no_summaries(self, enable_async_checkpoint_saving):
test_runner = TestRunner()
# Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model)
@@ -308,6 +312,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
@@ -317,7 +322,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertEmpty(tf.io.gfile.glob(
os.path.join(checkpoint_manager.directory, "events.*")))
def test_has_checkpoint_eval_summary_only(self):
@parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_has_checkpoint_eval_summary_only(
self, enable_async_checkpoint_saving
):
test_runner = TestRunner()
# Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model)
@@ -331,6 +342,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
steps_per_loop=2)
test_controller.train_and_evaluate(
@@ -344,7 +356,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotEmpty(tf.io.gfile.glob(
os.path.join(self.model_dir, "summaries/eval/events.*")))
def test_restore_from_most_recent_checkpoint(self):
@parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_restore_from_most_recent_checkpoint(
self, enable_async_checkpoint_saving
):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager(
@@ -357,6 +375,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
trainer=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
steps_per_loop=5)
test_controller.train(20)
@@ -364,9 +383,15 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
restored_path = test_controller.restore_checkpoint()
self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])
@parameterized.named_parameters(("return_numpy", True),
("return_tensor", False))
def test_train_and_evaluate(self, return_numpy):
@parameterized.named_parameters(
("return_numpy_sync_checkpoint_saving", True, False),
("return_numpy_async_checkpoint_saving", True, True),
("return_tensor_sync_checkpoint_saving", False, False),
("return_tensor_async_checkpoint_saving", False, True),
)
def test_train_and_evaluate(
self, return_numpy, enable_async_checkpoint_saving
):
test_runner = TestRunner(return_numpy=return_numpy)
checkpoint = tf.train.Checkpoint(
@@ -384,6 +409,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
@@ -403,7 +429,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_only(self):
@parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_train_only(self, enable_async_checkpoint_saving):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
@@ -420,6 +450,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
)
test_controller.train(steps=10)
@@ -497,7 +528,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
checkpoint_manager=checkpoint_manager)
test_controller.evaluate()
def test_already_trained_model(self):
@parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_already_trained_model(self, enable_async_checkpoint_saving):
test_runner = TestRunner()
test_runner.global_step.assign(10)
@@ -513,7 +548,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=2,
checkpoint_manager=checkpoint_manager)
checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving)
# `global_step` is already `train_steps`.
test_controller.train(steps=10)
@@ -533,7 +569,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
steps_per_loop=2,
summary_dir=os.path.join(self.model_dir, "summaries/train"),
summary_interval=2,
checkpoint_manager=checkpoint_manager,
checkpoint_manager=checkpoint_manager
)
test_controller.train(steps=10)
@@ -594,6 +630,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
interval = min(train_steps - self.global_step.numpy(), eval_interval)
num_steps = self.global_step.numpy() + interval
self.train(steps=num_steps, checkpoint_at_completion=False)
self._sync_on_async_checkpointing()
self.evaluate(steps=eval_steps)
# Early stop condition.
if test_runner.eval_loss.result() < 0.1:
@@ -672,7 +709,11 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
def test_eval_and_checkpoint_interval(self):
@parameterized.named_parameters(
("_sync_checkpoint_saving", False),
("_async_checkpoint_saving", True)
)
def test_eval_and_checkpoint_interval(self, enable_async_checkpoint_saving):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
@@ -689,6 +730,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
global_step=test_runner.global_step,
steps_per_loop=10,
checkpoint_manager=checkpoint_manager,
enable_async_checkpointing=enable_async_checkpoint_saving,
summary_dir=self.model_dir)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=5)
@@ -803,7 +845,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=steps_per_loop_fn,
checkpoint_manager=checkpoint_manager,
checkpoint_manager=checkpoint_manager
)
test_controller.train(steps=10)
self.assertEqual(test_runner.global_step, 10)