提交 fef5149c 编写于 作者: J jiangjiajun

add NCHW change

上级 0a1078ca
...@@ -116,7 +116,7 @@ class OpMapper(object): ...@@ -116,7 +116,7 @@ class OpMapper(object):
feeded_var_names=input_names, feeded_var_names=input_names,
target_vars=outputs, target_vars=outputs,
executor=exe, executor=exe,
params_filename="__params__") params_filename=None)
except: except:
raise Exception( raise Exception(
"Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually." "Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
......
...@@ -24,7 +24,7 @@ import sys ...@@ -24,7 +24,7 @@ import sys
class TFGraphNode(GraphNode): class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None): def __init__(self, layer, layer_name=None, data_format="NHWC"):
if layer_name is None: if layer_name is None:
super(TFGraphNode, super(TFGraphNode,
self).__init__(layer, self).__init__(layer,
...@@ -35,6 +35,8 @@ class TFGraphNode(GraphNode): ...@@ -35,6 +35,8 @@ class TFGraphNode(GraphNode):
layer_name.replace('/', '_').replace('-', '_')) layer_name.replace('/', '_').replace('-', '_'))
self.layer_type = layer.op self.layer_type = layer.op
self.tf_data_format = data_format
self.pd_data_format = "NCHW"
self.fluid_code = FluidCode() self.fluid_code = FluidCode()
self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"} self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"}
...@@ -86,15 +88,16 @@ class TFGraphNode(GraphNode): ...@@ -86,15 +88,16 @@ class TFGraphNode(GraphNode):
class TFGraph(Graph): class TFGraph(Graph):
def __init__(self, model): def __init__(self, model, data_format="NHWC"):
super(TFGraph, self).__init__(model) super(TFGraph, self).__init__(model)
self.identity_map = dict() self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV'] self.multi_out_ops = ['Split', 'SplitV']
self.tf_data_format = data_format
def build(self): def build(self):
for layer in self.model.node: for layer in self.model.node:
self.node_map[layer.name.replace('/', '_').replace( self.node_map[layer.name.replace('/', '_').replace(
'-', '_')] = TFGraphNode(layer) '-', '_')] = TFGraphNode(layer, data_format=self.tf_data_format)
for layer_name, node in self.node_map.items(): for layer_name, node in self.node_map.items():
for in_node in node.layer.input: for in_node in node.layer.input:
...@@ -166,9 +169,20 @@ class TFGraph(Graph): ...@@ -166,9 +169,20 @@ class TFGraph(Graph):
idx = self.output_nodes.index(node_name) idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name self.output_nodes[idx] = input_node.layer_name
def data_format_propagation(self, node):
current_node = self.node_map[node.layer_name]
current_node = node.tf_data_format
outputs = current_node.outputs
if len(outputs) == 0:
return
for out in outputs:
next_node = self.node_map[out]
next_node.tf_data_format = node.tf_data_format
self.data_format_propagation(next_node)
class TFDecoder(object): class TFDecoder(object):
def __init__(self, pb_model): def __init__(self, pb_model, data_format="NHWC"):
self.sess = tf.Session() self.sess = tf.Session()
self.input_info = dict() self.input_info = dict()
with gfile.FastGFile(pb_model, 'rb') as f: with gfile.FastGFile(pb_model, 'rb') as f:
...@@ -186,7 +200,7 @@ class TFDecoder(object): ...@@ -186,7 +200,7 @@ class TFDecoder(object):
self.sess.run(tf.global_variables_initializer()) self.sess.run(tf.global_variables_initializer())
self.tf_graph = TFGraph( self.tf_graph = TFGraph(
self.sess.graph._as_graph_def(add_shapes=True)[0]) self.sess.graph._as_graph_def(add_shapes=True)[0], data_format)
self.tf_graph.build() self.tf_graph.build()
def _fix_output_shape(self, graph): def _fix_output_shape(self, graph):
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册