提交 439d515a 编写于 作者: F Frederick Liu 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 417875109
上级 e6337e30
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for official.core.train_utils."""
import json
import os
import pprint
......@@ -138,5 +139,60 @@ class TrainUtilsTest(tf.test.TestCase):
self.assertEqual(params_from_obj.trainer.validation_steps, 11)
class BestCheckpointExporterTest(tf.test.TestCase):
def test_maybe_export(self):
model_dir = self.create_tempdir().full_path
best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1')
metric_name = 'test_metric|metric_1'
exporter = train_utils.BestCheckpointExporter(
model_dir, metric_name, 'higher')
v = tf.Variable(1.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 5.0}}, 100)
with self.subTest(name='Successful first save.'):
self.assertEqual(ret, True)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 1.0)
v = tf.Variable(3.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 6.0}}, 200)
with self.subTest(name='Successful better metic save.'):
self.assertEqual(ret, True)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 3.0)
v = tf.Variable(5.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 1.0}}, 300)
with self.subTest(name='Worse metic no save.'):
self.assertEqual(ret, False)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 3.0)
def test_export_best_eval_metric(self):
model_dir = self.create_tempdir().full_path
metric_name = 'test_metric|metric_1'
exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
'higher')
exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100)
with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
'rb') as reader:
metric = json.loads(reader.read())
self.assertAllEqual(
metric,
{'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册