From f8d6d99d1d8e9311e4cc062d353f4ee50cea61bd Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 5 Aug 2020 15:31:53 -0700
Subject: [PATCH] Update orbit.Controller: do not write training summary when
 input summary_dir is None.

PiperOrigin-RevId: 325115012
---
 .../resnet/resnet_ctl_imagenet_main.py |  1 +
 orbit/controller.py                    | 10 ++--
 orbit/controller_test.py               | 53 ++++++++++++++++++-
 3 files changed, 57 insertions(+), 7 deletions(-)

diff --git a/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py b/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
index ca0ccd9fd..7d77c2c2b 100644
--- a/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
+++ b/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
@@ -167,6 +167,7 @@ def run(flags_obj):
       steps_per_loop=steps_per_loop,
       checkpoint_manager=checkpoint_manager,
       summary_interval=summary_interval,
+      summary_dir=flags_obj.model_dir,
       eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))
 
   time_callback.on_train_begin()
diff --git a/orbit/controller.py b/orbit/controller.py
index 3370e556c..aca78c0ff 100644
--- a/orbit/controller.py
+++ b/orbit/controller.py
@@ -71,9 +71,11 @@ class Controller:
       `trainer.train` function will always be enabled. If set, the value
       should be divisible by steps_per_loop.
     summary_dir: The directory to restore and write checkpoints and summaries.
-      If None, it will be set to `checkpoint_manager.directory`.
+      For example, you can set it to `checkpoint_manager.directory`.
+      If None, it will not write training summaries.
     eval_summary_dir: The directory to write eval summaries. If None, it will
-      be set to `summary_dir`.
+      be set to `summary_dir`. If both `summary_dir` and `eval_summary_dir`
+      are None, it will not write evaluation summaries.
 
   Raises:
     ValueError: If both `trainer` and `evaluator` are None.
@@ -108,9 +110,6 @@ class Controller:
     self.global_step = global_step
     self.checkpoint_manager = checkpoint_manager
 
-    if summary_dir is None and checkpoint_manager:
-      summary_dir = checkpoint_manager.directory
-
     if self.trainer is not None:
       self.step_timer = None
       self.steps_per_loop = steps_per_loop
@@ -118,7 +117,6 @@ class Controller:
       self.summary_manager = utils.SummaryManager(
           summary_dir, tf.summary.scalar, global_step=self.global_step)
 
-    eval_summary_writer = None
     if self.evaluator is not None:
       eval_summary_dir = eval_summary_dir or summary_dir
       if eval_summary_dir == summary_dir and self.trainer is not None:
diff --git a/orbit/controller_test.py b/orbit/controller_test.py
index 8472de4fb..d0b5ea9ac 100644
--- a/orbit/controller_test.py
+++ b/orbit/controller_test.py
@@ -294,6 +294,56 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         train_steps=10, eval_steps=2, eval_interval=6)
     self.assertEqual(test_runner.global_step, 10)
 
+  def test_has_checkpoint_no_summaries(self):
+    test_runner = TestRunner()
+    # Has a checkpoint, but no summary directories.
+    checkpoint = tf.train.Checkpoint(model=test_runner.model)
+    checkpoint_manager = tf.train.CheckpointManager(
+        checkpoint,
+        self.model_dir,
+        max_to_keep=None,
+        step_counter=test_runner.global_step)
+    test_controller = controller.Controller(
+        trainer=test_runner,
+        evaluator=test_runner,
+        global_step=test_runner.global_step,
+        checkpoint_manager=checkpoint_manager,
+        steps_per_loop=2)
+    test_controller.train_and_evaluate(
+        train_steps=10, eval_steps=2, eval_interval=6)
+    self.assertEqual(test_runner.global_step, 10)
+
+    # No summaries are saved.
+    self.assertEmpty(tf.io.gfile.glob(
+        os.path.join(checkpoint_manager.directory, "events.*")))
+
+  def test_has_checkpoint_eval_summary_only(self):
+    test_runner = TestRunner()
+    # Has a checkpoint and an eval summary dir, but no training summary dir.
+    checkpoint = tf.train.Checkpoint(model=test_runner.model)
+    checkpoint_manager = tf.train.CheckpointManager(
+        checkpoint,
+        self.model_dir,
+        max_to_keep=None,
+        step_counter=test_runner.global_step)
+    test_controller = controller.Controller(
+        trainer=test_runner,
+        evaluator=test_runner,
+        global_step=test_runner.global_step,
+        checkpoint_manager=checkpoint_manager,
+        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
+        steps_per_loop=2)
+    test_controller.train_and_evaluate(
+        train_steps=10, eval_steps=2, eval_interval=6)
+    self.assertEqual(test_runner.global_step, 10)
+
+    # Training summaries are not saved.
+    self.assertEmpty(tf.io.gfile.glob(
+        os.path.join(checkpoint_manager.directory, "events.*")))
+    # Evaluation summaries are saved.
+    self.assertNotEmpty(tf.io.gfile.glob(
+        os.path.join(self.model_dir, "summaries/eval/events.*")))
+
   @parameterized.named_parameters(("return_numpy", True),
                                   ("return_tensor", False))
   def test_train_and_evaluate(self, return_numpy):
@@ -612,7 +662,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         evaluator=test_runner,
         global_step=test_runner.global_step,
         steps_per_loop=10,
-        checkpoint_manager=checkpoint_manager)
+        checkpoint_manager=checkpoint_manager,
+        summary_dir=self.model_dir)
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=5)
 
-- 
GitLab
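Usage note: after this change, callers that relied on training summaries implicitly landing in `checkpoint_manager.directory` must pass `summary_dir` explicitly, as the resnet_ctl_imagenet_main.py hunk above now does. Below is a minimal sketch of the new contract, assuming hypothetical `my_trainer` and `my_evaluator` objects standing in for real Orbit trainer/evaluator implementations (compare TestRunner in the tests above), each exposing `model` and `global_step` attributes; `model_dir` is likewise a placeholder:

    import os

    import tensorflow as tf

    from orbit import controller

    model_dir = "/tmp/example_model"  # hypothetical output directory

    # my_trainer / my_evaluator are hypothetical stand-ins for real Orbit
    # trainer/evaluator objects, assumed to expose `model` and `global_step`.
    checkpoint = tf.train.Checkpoint(model=my_trainer.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, model_dir, max_to_keep=3,
        step_counter=my_trainer.global_step)

    ctrl = controller.Controller(
        trainer=my_trainer,
        evaluator=my_evaluator,
        global_step=my_trainer.global_step,
        steps_per_loop=100,
        checkpoint_manager=checkpoint_manager,
        # Before this patch, summary_dir=None silently fell back to
        # checkpoint_manager.directory. Now None means "write no training
        # summaries", so the directory is passed explicitly here.
        summary_dir=model_dir,
        eval_summary_dir=os.path.join(model_dir, "eval"))
    ctrl.train_and_evaluate(train_steps=1000, eval_steps=10, eval_interval=500)

Leaving `summary_dir` unset (None) now skips training summaries entirely while still checkpointing, and with `eval_summary_dir` also None, evaluation summaries are skipped as well, which is exactly what the two new tests assert.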