From 439d515a059740ba926ec2442299300fe52aa101 Mon Sep 17 00:00:00 2001 From: Frederick Liu Date: Wed, 22 Dec 2021 13:17:34 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 417875109 --- official/core/train_utils_test.py | 56 +++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/official/core/train_utils_test.py b/official/core/train_utils_test.py index 2010736aa..42344853d 100644 --- a/official/core/train_utils_test.py +++ b/official/core/train_utils_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for official.core.train_utils.""" +import json import os import pprint @@ -138,5 +139,60 @@ class TrainUtilsTest(tf.test.TestCase): self.assertEqual(params_from_obj.trainer.validation_steps, 11) +class BestCheckpointExporterTest(tf.test.TestCase): + + def test_maybe_export(self): + model_dir = self.create_tempdir().full_path + best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1') + metric_name = 'test_metric|metric_1' + exporter = train_utils.BestCheckpointExporter( + model_dir, metric_name, 'higher') + v = tf.Variable(1.0) + checkpoint = tf.train.Checkpoint(v=v) + ret = exporter.maybe_export_checkpoint( + checkpoint, {'test_metric': {'metric_1': 5.0}}, 100) + with self.subTest(name='Successful first save.'): + self.assertEqual(ret, True) + v_2 = tf.Variable(2.0) + checkpoint_2 = tf.train.Checkpoint(v=v_2) + checkpoint_2.restore(best_ckpt_path) + self.assertEqual(v_2.numpy(), 1.0) + + v = tf.Variable(3.0) + checkpoint = tf.train.Checkpoint(v=v) + ret = exporter.maybe_export_checkpoint( + checkpoint, {'test_metric': {'metric_1': 6.0}}, 200) + with self.subTest(name='Successful better metic save.'): + self.assertEqual(ret, True) + v_2 = tf.Variable(2.0) + checkpoint_2 = tf.train.Checkpoint(v=v_2) + checkpoint_2.restore(best_ckpt_path) + self.assertEqual(v_2.numpy(), 3.0) + + v = tf.Variable(5.0) + checkpoint = tf.train.Checkpoint(v=v) + ret = exporter.maybe_export_checkpoint( + checkpoint, {'test_metric': {'metric_1': 1.0}}, 300) + with self.subTest(name='Worse metic no save.'): + self.assertEqual(ret, False) + v_2 = tf.Variable(2.0) + checkpoint_2 = tf.train.Checkpoint(v=v_2) + checkpoint_2.restore(best_ckpt_path) + self.assertEqual(v_2.numpy(), 3.0) + + def test_export_best_eval_metric(self): + model_dir = self.create_tempdir().full_path + metric_name = 'test_metric|metric_1' + exporter = train_utils.BestCheckpointExporter(model_dir, metric_name, + 'higher') + exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100) + with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'), + 'rb') as reader: + metric = json.loads(reader.read()) + self.assertAllEqual( + metric, + {'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0}) + + if __name__ == '__main__': tf.test.main() -- GitLab