提交 7d8ce2ee 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fix freeze_graph.py and freeze_graph_test.py to work with Saver V1 and V2.

Change: 136743848
上级 8a0e81d8
......@@ -80,7 +80,8 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
print("Input saver file '" + input_saver + "' does not exist!")
return -1
if not tf.gfile.Glob(input_checkpoint):
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if not tf.train.checkpoint_exists(input_checkpoint):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1
......
......@@ -27,7 +27,7 @@ from tensorflow.python.tools import freeze_graph
class FreezeGraphTest(test_util.TensorFlowTestCase):
def testFreezeGraph(self):
def _testFreezeGraph(self, saver_write_version):
checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
checkpoint_state_name = "checkpoint_state"
......@@ -44,8 +44,8 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
sess.run(init)
output = sess.run(output_node)
self.assertNear(2.0, output, 0.00001)
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
saver.save(sess, checkpoint_prefix, global_step=0,
saver = tf.train.Saver(write_version=saver_write_version)
checkpoint_path = saver.save(sess, checkpoint_prefix, global_step=0,
latest_filename=checkpoint_state_name)
tf.train.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)
......@@ -54,7 +54,6 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
input_saver_def_path = ""
input_binary = False
input_checkpoint_path = checkpoint_prefix + "-0"
output_node_names = "output_node"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
......@@ -62,7 +61,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
clear_devices = False
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
input_binary, input_checkpoint_path,
input_binary, checkpoint_path,
output_node_names, restore_op_name,
filename_tensor_name, output_graph_path,
clear_devices, "")
......@@ -84,5 +83,11 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output = sess.run(output_node)
self.assertNear(2.0, output, 0.00001)
def testFreezeGraphV1(self):
self._testFreezeGraph(tf.train.SaverDef.V1)
def testFreezeGraphV2(self):
self._testFreezeGraph(tf.train.SaverDef.V2)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册