提交 13e73084 编写于 作者: J jiangjiajun

add trick method for tf2fluid

上级 804beb33
...@@ -15,56 +15,8 @@ ...@@ -15,56 +15,8 @@
# TODO useless node remove # TODO useless node remove
from x2paddle.decoder.tf_decoder import TFGraph from x2paddle.decoder.tf_decoder import TFGraph
# TODO bn merge
class TFGraphOptimizer(object): # TODO activation merge
def __init__(self):
print("Doint Nothing")
def remove_isolated_node(self, graph): # TODO biasadd merge
# delete isolated nodes
isolated_nodes = list()
for node_name in graph.node_map.keys():
if len(graph.get_node(node_name).inputs) == 0 or len(
graph.get_node(node_name).outputs) == 0:
isolated_nodes.append(node_name)
graph.remove_node(node_name)
def remove_identity_node(self, graph):
identity_node = list()
for node_name, node in graph.node_map.items():
if node.layer_type == "Identity":
identity_node.append(node_name)
for node_name in identity_node:
node = graph.get_node(node_name)
# Remind: Only 1 input for Identity node
input_node = graph.get_node(node.inputs[0])
# remove identity node from graph
idx = input_node.outputs.index(node_name)
del input_node.outputs[idx]
output_names = node.outputs
for output_name in output_names:
output_node = graph.get_node(output_name)
idx = output_node.inputs.index(node_name)
output_node.inputs[idx] = input_node.layer_name
idx = graph.topo_sort.index(node_name)
del graph.topo_sort[idx]
def run(self, graph):
self.remove_isolated_node(graph)
self.remove_identity_node(graph)
# TODO identity node remove
# TODO subgraph optimize
# TODO compute optimize
# activation merge
# biasadd merge
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册