提交 24e799a9 编写于 作者: J jiangjiajun

add more models support for tensorflow

上级 7ca9c323
......@@ -314,15 +314,9 @@ class TFOpMapper(OpMapper):
input_name = input_name + "[{}]".format(input.index)
node.fluid_code.add_layer("{} = {}").format(node.layer_name,
input_name)
#
# node.fluid_code.add_layer("assign",
# inputs=input,
# output=node,
# param_attr=None)
node.tf_data_format = "NHWC"
self.graph.data_format_propagation(node)
elif len(input.out_shapes[0]) > 4:
print(input.layer_name, input.tf_data_format, input.pd_data_format)
tf_data_format = list(input.tf_data_format)
pd_data_format = list(input.pd_data_format)
new_perm = [i for i in range(len(perm))]
......
......@@ -39,11 +39,6 @@ class TFOptimizer(object):
self.graph = op_mapper.graph
def delete_redundance_code(self):
# print("==========omit_nodes============")
# for node_name in set(self.op_mapper.omit_nodes):
# node = self.graph.get_node(node_name)
# print(node.layer_name, self.op_mapper.omit_nodes.count(node.layer_name), len(node.outputs), node.outputs)
# print("================================")
for node_name in self.graph.topo_sort:
if node_name in self.op_mapper.omit_nodes:
node = self.graph.get_node(node_name)
......@@ -67,13 +62,6 @@ class TFOptimizer(object):
del self.graph.node_map[node_name]
def strip_graph(self):
# print("=============")
# for i, node_name in enumerate(self.graph.topo_sort):
# node = self.graph.get_node(node_name)
# if node is None:
# continue
# print(node.layer_name, node.inputs)
# print("================")
visited_nodes = set()
def visit(node_name):
......@@ -87,10 +75,6 @@ class TFOptimizer(object):
for node_name in self.graph.output_nodes:
visit(node_name)
# print("=============visited nodes++++++++++++")
# for name in visited_nodes:
# print(name)
# print("===================================")
for i, node_name in enumerate(self.graph.topo_sort):
if node_name not in visited_nodes:
node = self.graph.get_node(node_name)
......@@ -221,9 +205,6 @@ class TFOptimizer(object):
if out_node.layer_type == "BiasAdd":
del out_node.fluid_code.layers[0]
out_node.fluid_code.layers[0].inputs['x'] = last_out
# out_node.fluid_code.layers[0].param_attr["axis"] = 1
else:
del out_node.fluid_code.layers[0]
out_node.fluid_code.layers[0].inputs = last_out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册