diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index 18f1e9059651533e9586217e06c3316acce32b36..b6a456833987aaebb236e9eaf12533557ce803dd 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -80,7 +80,8 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, print("Input saver file '" + input_saver + "' does not exist!") return -1 - if not tf.gfile.Glob(input_checkpoint): + # 'input_checkpoint' may be a prefix if we're using Saver V2 format + if not tf.train.checkpoint_exists(input_checkpoint): print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") return -1 diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index cc4eaa95b3cc68b3a8c8cf21981e377f42515b5d..79745ff1d0d33e6ebd5f30e87120cf59a4401def 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -27,7 +27,7 @@ from tensorflow.python.tools import freeze_graph class FreezeGraphTest(test_util.TensorFlowTestCase): - def testFreezeGraph(self): + def _testFreezeGraph(self, saver_write_version): checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") checkpoint_state_name = "checkpoint_state" @@ -44,9 +44,9 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): sess.run(init) output = sess.run(output_node) self.assertNear(2.0, output, 0.00001) - saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) - saver.save(sess, checkpoint_prefix, global_step=0, - latest_filename=checkpoint_state_name) + saver = tf.train.Saver(write_version=saver_write_version) + checkpoint_path = saver.save(sess, checkpoint_prefix, global_step=0, + latest_filename=checkpoint_state_name) tf.train.write_graph(sess.graph, self.get_temp_dir(), input_graph_name) # We save out the graph to disk, and then call the const conversion @@ -54,7 +54,6 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name) input_saver_def_path = "" input_binary = False - input_checkpoint_path = checkpoint_prefix + "-0" output_node_names = "output_node" restore_op_name = "save/restore_all" filename_tensor_name = "save/Const:0" @@ -62,7 +61,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): clear_devices = False freeze_graph.freeze_graph(input_graph_path, input_saver_def_path, - input_binary, input_checkpoint_path, + input_binary, checkpoint_path, output_node_names, restore_op_name, filename_tensor_name, output_graph_path, clear_devices, "") @@ -84,5 +83,11 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): output = sess.run(output_node) self.assertNear(2.0, output, 0.00001) + def testFreezeGraphV1(self): + self._testFreezeGraph(tf.train.SaverDef.V1) + + def testFreezeGraphV2(self): + self._testFreezeGraph(tf.train.SaverDef.V2) + if __name__ == "__main__": tf.test.main()