提交 4ecd2a70 编写于 作者: S Sherry Moore 提交者: TensorFlower Gardener

Added unit test for max_to_keep being None.

Change: 115516426
上级 77da168d
......@@ -37,6 +37,14 @@ from tensorflow.python.framework import function
from tensorflow.python.platform import gfile
def _TestDir(test_name):
test_dir = os.path.join(tf.test.get_temp_dir(), test_name)
if os.path.exists(test_dir):
shutil.rmtree(test_dir)
gfile.MakeDirs(test_dir)
return test_dir
class SaverTest(tf.test.TestCase):
def testBasics(self):
......@@ -349,12 +357,7 @@ class SaveRestoreShardedTest(tf.test.TestCase):
class MaxToKeepTest(tf.test.TestCase):
def testNonSharded(self):
save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_non_sharded")
try:
gfile.DeleteRecursively(save_dir)
except OSError:
pass # Ignore
gfile.MakeDirs(save_dir)
save_dir = _TestDir("max_to_keep_non_sharded")
with self.test_session() as sess:
v = tf.Variable(10.0, name="v")
......@@ -456,12 +459,7 @@ class MaxToKeepTest(tf.test.TestCase):
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
def testSharded(self):
save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_sharded")
try:
gfile.DeleteRecursively(save_dir)
except OSError:
pass # Ignore
gfile.MakeDirs(save_dir)
save_dir = _TestDir("max_to_keep_sharded")
with tf.Session(
target="",
......@@ -495,17 +493,39 @@ class MaxToKeepTest(tf.test.TestCase):
self.assertEqual(2, len(gfile.Glob(s3)))
self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
def testNoMaxToKeep(self):
save_dir = _TestDir("no_max_to_keep")
save_dir2 = _TestDir("max_to_keep_0")
with self.test_session() as sess:
v = tf.Variable(10.0, name="v")
tf.initialize_all_variables().run()
# Test max_to_keep being None.
save = tf.train.Saver({"v": v}, max_to_keep=None)
self.assertEqual([], save.last_checkpoints)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([], save.last_checkpoints)
self.assertTrue(gfile.Exists(s1))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([], save.last_checkpoints)
self.assertTrue(gfile.Exists(s2))
# Test max_to_keep being 0.
save2 = tf.train.Saver({"v": v}, max_to_keep=0)
self.assertEqual([], save2.last_checkpoints)
s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
self.assertEqual([], save2.last_checkpoints)
self.assertTrue(gfile.Exists(s1))
s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
self.assertEqual([], save2.last_checkpoints)
self.assertTrue(gfile.Exists(s2))
class KeepCheckpointEveryNHoursTest(tf.test.TestCase):
def testNonSharded(self):
save_dir = os.path.join(self.get_temp_dir(),
"keep_checkpoint_every_n_hours")
try:
gfile.DeleteRecursively(save_dir)
except OSError:
pass # Ignore
gfile.MakeDirs(save_dir)
save_dir = _TestDir("keep_checkpoint_every_n_hours")
with self.test_session() as sess:
v = tf.Variable([10.0], name="v")
......@@ -685,15 +705,8 @@ class LatestCheckpointWithRelativePaths(tf.test.TestCase):
class CheckpointStateTest(tf.test.TestCase):
def _TestDir(self, test_name):
test_dir = os.path.join(self.get_temp_dir(), test_name)
if os.path.exists(test_dir):
shutil.rmtree(test_dir)
gfile.MakeDirs(test_dir)
return test_dir
def testAbsPath(self):
save_dir = self._TestDir("abs_paths")
save_dir = _TestDir("abs_paths")
abs_path = os.path.join(save_dir, "model-0")
ckpt = tf.train.generate_checkpoint_state_proto(save_dir, abs_path)
self.assertEqual(ckpt.model_checkpoint_path, abs_path)
......@@ -712,7 +725,7 @@ class CheckpointStateTest(tf.test.TestCase):
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
def testAllModelCheckpointPaths(self):
save_dir = self._TestDir("all_models_test")
save_dir = _TestDir("all_models_test")
abs_path = os.path.join(save_dir, "model-0")
for paths in [None, [], ["model-2"]]:
ckpt = tf.train.generate_checkpoint_state_proto(
......@@ -726,7 +739,7 @@ class CheckpointStateTest(tf.test.TestCase):
self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
def testUpdateCheckpointState(self):
save_dir = self._TestDir("update_checkpoint_state")
save_dir = _TestDir("update_checkpoint_state")
os.chdir(save_dir)
# Make a temporary train directory.
train_dir = "train"
......@@ -746,15 +759,8 @@ class CheckpointStateTest(tf.test.TestCase):
class MetaGraphTest(tf.test.TestCase):
def _TestDir(self, test_name):
test_dir = os.path.join(self.get_temp_dir(), test_name)
if os.path.exists(test_dir):
shutil.rmtree(test_dir)
gfile.MakeDirs(test_dir)
return test_dir
def testAddCollectionDef(self):
test_dir = self._TestDir("good_collection")
test_dir = _TestDir("good_collection")
filename = os.path.join(test_dir, "metafile")
with self.test_session():
# Creates a graph.
......@@ -819,7 +825,7 @@ class MetaGraphTest(tf.test.TestCase):
self.assertEqual(len(meta_graph_def.collection_def), 0)
def _testMultiSaverCollectionSave(self):
test_dir = self._TestDir("saver_collection")
test_dir = _TestDir("saver_collection")
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
......@@ -894,7 +900,7 @@ class MetaGraphTest(tf.test.TestCase):
self._testMultiSaverCollectionRestore()
def testBinaryAndTextFormat(self):
test_dir = self._TestDir("binary_and_text")
test_dir = _TestDir("binary_and_text")
filename = os.path.join(test_dir, "metafile")
with self.test_session(graph=tf.Graph()):
# Creates a graph.
......@@ -924,7 +930,7 @@ class MetaGraphTest(tf.test.TestCase):
tf.train.import_meta_graph(filename)
def testSliceVariable(self):
test_dir = self._TestDir("slice_saver")
test_dir = _TestDir("slice_saver")
filename = os.path.join(test_dir, "metafile")
with self.test_session():
v1 = tf.Variable([20.0], name="v1")
......@@ -946,7 +952,7 @@ class MetaGraphTest(tf.test.TestCase):
self.assertProtoEquals(meta_graph_def, new_meta_graph_def)
def _testGraphExtensionSave(self):
test_dir = self._TestDir("graph_extension")
test_dir = _TestDir("graph_extension")
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
with self.test_session(graph=tf.Graph()) as sess:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册