提交 221cfa1e 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Chains of two or more "Identity" nodes weren't being spliced correctly.

Change: 150043496
上级 dc17f76f
......@@ -307,10 +307,10 @@ def remove_training_nodes(input_graph):
del new_node.input[:]
for full_input_name in input_before_removal:
input_name = re.sub(r"^\^", "", full_input_name)
if input_name in names_to_splice:
new_node.input.append(names_to_splice[input_name])
else:
new_node.input.append(full_input_name)
while input_name in names_to_splice:
full_input_name = names_to_splice[input_name]
input_name = re.sub(r"^\^", "", full_input_name)
new_node.input.append(full_input_name)
nodes_after_splicing.append(new_node)
output_graph = graph_pb2.GraphDef()
......
......@@ -318,6 +318,29 @@ class DeviceFunctionsTest(test.TestCase):
output = graph_util.remove_training_nodes(graph_def)
self.assertProtoEquals(expected_output, output)
def testRemoveIdentityChains(self):
"""Check that chains of Identity nodes are correctly pruned.
Create a chain of four nodes, A, B, C, and D where A inputs B, B inputs C,
and C inputs D. Nodes B and C are "Identity" and should be pruned, resulting
in the nodes A and D, where A inputs D.
"""
graph_def = graph_pb2.GraphDef()
graph_def.node.extend([
self.create_node_def("Aop", "A", ["B"]), self.create_node_def(
"Identity", "B", ["C"]), self.create_node_def(
"Identity", "C", ["D"]), self.create_node_def("Dop", "D", [])
])
expected_graph_def = graph_pb2.GraphDef()
expected_graph_def.node.extend([
self.create_node_def("Aop", "A", ["D"]), self.create_node_def(
"Dop", "D", [])
])
self.assertProtoEquals(expected_graph_def,
graph_util.remove_training_nodes(graph_def))
if __name__ == "__main__":
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册