提交 f8d6d99d 编写于 作者: A A. Unique TensorFlower

Update orbit.Controller: do not write training summary when input summary_dir is None.

PiperOrigin-RevId: 325115012
上级 b3677ae2
......@@ -167,6 +167,7 @@ def run(flags_obj):
steps_per_loop=steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_interval=summary_interval,
summary_dir=flags_obj.model_dir,
eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))
time_callback.on_train_begin()
......
......@@ -71,9 +71,11 @@ class Controller:
`trainer.train` function will always be enabled. If set, the value
should be divisible by steps_per_loop.
summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`.
For example, You can set it to `checkpoint_manager.directory`.
If None, it will not write training summarizes.
eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`.
be set to `summary_dir`. If both `summary_dir` and `eval_summary_dir`
are None, it will not write evaluation summarizes.
Raises:
ValueError: If both `trainer` and `evaluator` are None.
......@@ -108,9 +110,6 @@ class Controller:
self.global_step = global_step
self.checkpoint_manager = checkpoint_manager
if summary_dir is None and checkpoint_manager:
summary_dir = checkpoint_manager.directory
if self.trainer is not None:
self.step_timer = None
self.steps_per_loop = steps_per_loop
......@@ -118,7 +117,6 @@ class Controller:
self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step)
eval_summary_writer = None
if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir
if eval_summary_dir == summary_dir and self.trainer is not None:
......
......@@ -294,6 +294,56 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
def test_has_checkpoint_no_summaries(self):
test_runner = TestRunner()
# Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# No summaries are saved.
self.assertEmpty(tf.io.gfile.glob(
os.path.join(checkpoint_manager.directory, "events.*")))
def test_has_checkpoint_eval_summary_only(self):
test_runner = TestRunner()
# Has checkpoint, but no summary directories.
checkpoint = tf.train.Checkpoint(model=test_runner.model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step)
test_controller = controller.Controller(
trainer=test_runner,
evaluator=test_runner,
global_step=test_runner.global_step,
checkpoint_manager=checkpoint_manager,
eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
steps_per_loop=2)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10)
# Training summaries are not saved.
self.assertEmpty(tf.io.gfile.glob(
os.path.join(checkpoint_manager.directory, "events.*")))
# Evaluation summaries are saved.
self.assertNotEmpty(tf.io.gfile.glob(
os.path.join(self.model_dir, "summaries/eval/events.*")))
@parameterized.named_parameters(("return_numpy", True),
("return_tensor", False))
def test_train_and_evaluate(self, return_numpy):
......@@ -612,7 +662,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
evaluator=test_runner,
global_step=test_runner.global_step,
steps_per_loop=10,
checkpoint_manager=checkpoint_manager)
checkpoint_manager=checkpoint_manager,
summary_dir=self.model_dir)
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册