提交 fef5149c 编写于 作者: J jiangjiajun

add NCHW change

上级 0a1078ca
......@@ -116,7 +116,7 @@ class OpMapper(object):
feeded_var_names=input_names,
target_vars=outputs,
executor=exe,
params_filename="__params__")
params_filename=None)
except:
raise Exception(
"Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
......
......@@ -24,7 +24,7 @@ import sys
class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None):
def __init__(self, layer, layer_name=None, data_format="NHWC"):
if layer_name is None:
super(TFGraphNode,
self).__init__(layer,
......@@ -35,6 +35,8 @@ class TFGraphNode(GraphNode):
layer_name.replace('/', '_').replace('-', '_'))
self.layer_type = layer.op
self.tf_data_format = data_format
self.pd_data_format = "NCHW"
self.fluid_code = FluidCode()
self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"}
......@@ -86,15 +88,16 @@ class TFGraphNode(GraphNode):
class TFGraph(Graph):
def __init__(self, model):
def __init__(self, model, data_format="NHWC"):
super(TFGraph, self).__init__(model)
self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV']
self.tf_data_format = data_format
def build(self):
for layer in self.model.node:
self.node_map[layer.name.replace('/', '_').replace(
'-', '_')] = TFGraphNode(layer)
'-', '_')] = TFGraphNode(layer, data_format=self.tf_data_format)
for layer_name, node in self.node_map.items():
for in_node in node.layer.input:
......@@ -166,9 +169,20 @@ class TFGraph(Graph):
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
def data_format_propagation(self, node):
current_node = self.node_map[node.layer_name]
current_node = node.tf_data_format
outputs = current_node.outputs
if len(outputs) == 0:
return
for out in outputs:
next_node = self.node_map[out]
next_node.tf_data_format = node.tf_data_format
self.data_format_propagation(next_node)
class TFDecoder(object):
def __init__(self, pb_model):
def __init__(self, pb_model, data_format="NHWC"):
self.sess = tf.Session()
self.input_info = dict()
with gfile.FastGFile(pb_model, 'rb') as f:
......@@ -186,7 +200,7 @@ class TFDecoder(object):
self.sess.run(tf.global_variables_initializer())
self.tf_graph = TFGraph(
self.sess.graph._as_graph_def(add_shapes=True)[0])
self.sess.graph._as_graph_def(add_shapes=True)[0], data_format)
self.tf_graph.build()
def _fix_output_shape(self, graph):
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册