提交 90dedf26 编写于 作者: R Ruoxin Sang 提交者: A. Unique TensorFlower

Allow `steps_per_loop` in Controller to be passed as a callable.

PiperOrigin-RevId: 466412169
上级 db19ab9b
......@@ -94,7 +94,7 @@ class Controller:
train_actions: Optional[Iterable[Action]] = None,
eval_actions: Optional[Iterable[Action]] = None,
# Train related
steps_per_loop: Optional[int] = None,
steps_per_loop: Optional[Union[int, Callable[[int], int]]] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
# Summary related
summary_interval: Optional[int] = None,
......@@ -130,8 +130,11 @@ class Controller:
output of `trainer.train`.
eval_actions: Optional `orbit.Action`s to call after each evaluation.
These will be called with the output of `evaluator.evaluate`.
steps_per_loop: The number of steps to run in each inner loop of training
(passed as the `num_steps` parameter of `trainer.train`).
steps_per_loop: Optional integer indicating the number of steps to run in
each inner loop of training (passed as the `num_steps` parameter of
`trainer.train`). It can also be a callable that takes the current
global step value as input and returns the number of steps to run as
output.
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
provided and there are checkpoints in the associated model directory,
the model will be restored from the most recent checkpoint inside this
......@@ -152,7 +155,7 @@ class Controller:
Raises:
ValueError: If both `trainer` and `evaluator` are `None`.
ValueError: If `steps_per_loop` is not a positive integer.
ValueError: If `steps_per_loop` is not a positive integer or a callable.
ValueError: If `summary_interval` is not a positive integer or, when
`steps_per_loop` is an integer, is not divisible by `steps_per_loop`.
"""
......@@ -163,15 +166,18 @@ class Controller:
if steps_per_loop is None:
raise ValueError(
"`steps_per_loop` is required when `trainer` is provided.")
elif not isinstance(steps_per_loop, int) or steps_per_loop < 1:
elif not callable(steps_per_loop) and (
not isinstance(steps_per_loop, int) or steps_per_loop < 1):
raise ValueError(
f"`steps_per_loop` ({steps_per_loop}) must be a positive integer.")
f"`steps_per_loop` ({steps_per_loop}) must be a positive integer "
"or a callable.")
if summary_interval is not None:
if summary_interval <= 0:
raise ValueError(
f"`summary_interval` ({summary_interval}) must be larger than 0.")
elif summary_interval % steps_per_loop != 0:
elif not callable(steps_per_loop) and (summary_interval % steps_per_loop
!= 0):
raise ValueError(
f"`summary interval` ({summary_interval}) must be a multiple "
f"of `steps_per_loop` ({steps_per_loop}).")
......@@ -192,10 +198,10 @@ class Controller:
if self.trainer is not None:
self.step_timer = None
self.steps_per_loop = steps_per_loop
self.summary_interval = summary_interval
self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step)
self._steps_per_loop = steps_per_loop
if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir
......@@ -316,9 +322,6 @@ class Controller:
results in a shorter inner loop than specified by `steps_per_loop`
setting. If None, evaluation will only be performed after training is
complete.
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
self._require("trainer", for_method="train_and_evaluate")
self._require("evaluator", for_method="train_and_evaluate")
......@@ -410,6 +413,13 @@ class Controller:
self._require("checkpoint_manager", for_method="save_checkpoint")
self._maybe_save_checkpoint(check_interval=False)
@property
def steps_per_loop(self):
  """Returns the `steps_per_loop` value for the current point in training.

  If the configured value is a callable, it is invoked with the current
  global step and its result is returned. Otherwise the configured value is
  returned unchanged.
  """
  if callable(self._steps_per_loop):
    # `numpy()` is expected to yield a NumPy scalar rather than a built-in
    # `int`; convert it so the callable receives the plain integer that its
    # `Callable[[int], int]` annotation promises.
    return self._steps_per_loop(int(self.global_step.numpy()))
  return self._steps_per_loop
def _train_n_steps(self, num_steps: int):
"""Runs training for `num_steps` steps.
......
......@@ -770,6 +770,32 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertIn("eval_loss", output)
self.assertGreaterEqual(output["eval_loss"], 0)
def test_step_per_loop_callable(self):
  """Trains with a callable `steps_per_loop` and checks the final step."""
  runner = TestRunner()
  manager = tf.train.CheckpointManager(
      tf.train.Checkpoint(model=runner.model, optimizer=runner.optimizer),
      self.model_dir,
      max_to_keep=None,
      step_counter=runner.global_step,
      checkpoint_interval=10)

  def steps_per_loop_fn(global_step):
    # Returns 2 until the global step exceeds 4, and 4 afterwards.
    return 4 if global_step > 4 else 2

  callable_controller = controller.Controller(
      trainer=runner,
      global_step=runner.global_step,
      steps_per_loop=steps_per_loop_fn,
      checkpoint_manager=manager,
  )
  callable_controller.train(steps=10)
  self.assertEqual(runner.global_step, 10)
# Script entry point: run `tf.test.main()` when this module is executed
# directly rather than imported.
if __name__ == "__main__":
  tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册