Commit 24e799a9 authored by J jiangjiajun

add support for more TensorFlow models

Parent 7ca9c323
@@ -314,15 +314,9 @@ class TFOpMapper(OpMapper):
             input_name = input_name + "[{}]".format(input.index)
             node.fluid_code.add_layer("{} = {}").format(node.layer_name,
                                                         input_name)
-            #
-            # node.fluid_code.add_layer("assign",
-            #                           inputs=input,
-            #                           output=node,
-            #                           param_attr=None)
             node.tf_data_format = "NHWC"
             self.graph.data_format_propagation(node)
         elif len(input.out_shapes[0]) > 4:
-            print(input.layer_name, input.tf_data_format, input.pd_data_format)
             tf_data_format = list(input.tf_data_format)
             pd_data_format = list(input.pd_data_format)
             new_perm = [i for i in range(len(perm))]
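The hunk above ends by building `tf_data_format`, `pd_data_format`, and `new_perm` for inputs with more than four dimensions, which suggests the transpose permutation, written against the TF axis order, gets re-expressed against the Paddle axis order. Below is a minimal sketch of one way such a remapping can work; the helper name, the exact logic, and the example layouts are assumptions rather than code from this commit:

```python
def remap_perm(perm, tf_data_format, pd_data_format):
    """Re-express a transpose permutation given in the TF axis order
    against the Paddle axis order.

    Both format arguments are lists of unique axis labels, e.g.
    list("NDHWC") and list("NCDHW"); `perm` indexes into the TF order.
    """
    new_perm = [i for i in range(len(perm))]
    for i in range(len(perm)):
        out_label = tf_data_format[i]        # label of output axis i in TF order
        in_label = tf_data_format[perm[i]]   # label of the input axis it reads from
        new_perm[pd_data_format.index(out_label)] = pd_data_format.index(in_label)
    return new_perm

# Hypothetical example: swapping W and C in NDHWC, re-expressed for NCDHW
print(remap_perm([0, 1, 2, 4, 3], list("NDHWC"), list("NCDHW")))  # [0, 4, 2, 3, 1]
```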
@@ -39,11 +39,6 @@ class TFOptimizer(object):
         self.graph = op_mapper.graph
 
     def delete_redundance_code(self):
-        # print("==========omit_nodes============")
-        # for node_name in set(self.op_mapper.omit_nodes):
-        #     node = self.graph.get_node(node_name)
-        #     print(node.layer_name, self.op_mapper.omit_nodes.count(node.layer_name), len(node.outputs), node.outputs)
-        # print("================================")
         for node_name in self.graph.topo_sort:
             if node_name in self.op_mapper.omit_nodes:
                 node = self.graph.get_node(node_name)
@@ -67,13 +62,6 @@ class TFOptimizer(object):
                 del self.graph.node_map[node_name]
 
     def strip_graph(self):
-        # print("=============")
-        # for i, node_name in enumerate(self.graph.topo_sort):
-        #     node = self.graph.get_node(node_name)
-        #     if node is None:
-        #         continue
-        #     print(node.layer_name, node.inputs)
-        # print("================")
         visited_nodes = set()
 
         def visit(node_name):
@@ -87,10 +75,6 @@ class TFOptimizer(object):
         for node_name in self.graph.output_nodes:
             visit(node_name)
 
-        # print("=============visited nodes++++++++++++")
-        # for name in visited_nodes:
-        #     print(name)
-        # print("===================================")
         for i, node_name in enumerate(self.graph.topo_sort):
             if node_name not in visited_nodes:
                 node = self.graph.get_node(node_name)
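The `strip_graph` hunks above belong to a pass that keeps only nodes reachable, walking backwards, from the graph outputs. Here is a minimal sketch of that reachability pruning, with the elided body of `visit` filled in by assumption; the actual implementation may differ in detail:

```python
def strip_graph(graph):
    """Delete every node that no output node depends on (sketch)."""
    visited_nodes = set()

    def visit(node_name):
        # Depth-first walk from an output node towards the graph inputs.
        if node_name in visited_nodes:
            return
        visited_nodes.add(node_name)
        node = graph.get_node(node_name)
        if node is None:
            return
        for input_name in node.inputs:
            visit(input_name)

    for node_name in graph.output_nodes:
        visit(node_name)

    # Anything never visited cannot influence an output; drop it.
    for node_name in list(graph.topo_sort):
        if node_name not in visited_nodes and node_name in graph.node_map:
            del graph.node_map[node_name]
```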
@@ -221,9 +205,6 @@ class TFOptimizer(object):
                 if out_node.layer_type == "BiasAdd":
                     del out_node.fluid_code.layers[0]
                     out_node.fluid_code.layers[0].inputs['x'] = last_out
-                    # out_node.fluid_code.layers[0].param_attr["axis"] = 1
                 else:
                     del out_node.fluid_code.layers[0]
                     out_node.fluid_code.layers[0].inputs = last_out
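The last hunk splices the first emitted layer out of the downstream node and feeds the surviving layer directly from `last_out`, the upstream node's output: a `BiasAdd` consumer only has its `x` operand rewired, while any other consumer has its whole `inputs` field replaced. A small sketch of that splice, with the node and `fluid_code` structure assumed purely for illustration:

```python
def splice_first_layer(out_node, last_out):
    """Drop out_node's leading emitted layer and re-point the next layer
    at the upstream output `last_out` (structure assumed for illustration)."""
    del out_node.fluid_code.layers[0]
    if out_node.layer_type == "BiasAdd":
        # elementwise-style layer: only the 'x' operand comes from upstream
        out_node.fluid_code.layers[0].inputs['x'] = last_out
    else:
        # single-input layer: replace the whole inputs field
        out_node.fluid_code.layers[0].inputs = last_out
```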