提交 51fcdabb 编写于 作者: J jiangjiajun

finish parser emitter optimizer demo

上级 ce716f9c
......@@ -13,11 +13,16 @@
# limitations under the License.
from x2paddle.parser.tf_parser import TFParser
from x2paddle.optimizer.tf_optimizer import TFGraphOptimizer
from x2paddle.emitter.tf_emitter import TFEmitter
parser = TFParser('/ssd3/dltpsz/frozen_darknet_yolov3_model.pb',
in_nodes=['inputs'],
out_nodes=['output_boxes'],
in_shapes=[[-1, 416, 416, 3]])
parser = TFParser('/ssd3/dltpsz/frozen_darknet_yolov3_model.pb',
in_nodes=['inputs'], out_nodes=['output_boxes'],
in_shapes=[[-1, 416, 416, 3]])
optimizer = TFGraphOptimizer()
optimizer.remove_useless_node(parser.tf_graph)
parser.tf_graph.print()
optimizer.run(parser.tf_graph)
#parser.tf_graph.print()
emitter = TFEmitter(parser)
emitter.run()
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.graph import GraphNode
class Layer(object):
def __init__(self):
......@@ -23,18 +25,22 @@ class Layer(object):
def get_code(self):
layer_code = ""
if self.output is not None:
layer_code = self.output + " = "
if isinstance(self.output, str):
layer_code = self.output + " = "
else:
layer_code = self.output.layer_name + " = "
layer_code = layer_code + "fluid.layers." + self.op + "("
for key, tensor in self.inputs.items():
layer_code = layer_code + key + "=" + tensor + ", "
layer_code = layer_code + key + "={}, ".format(tensor)
for key, value in self.param_attr.items():
layer_code = layer_code + key + "=" + value + ", "
layer_code = layer_code + key + "={}, ".format(value)
layer_code = layer_code.strip(", ")
return layer_code += ")"
return layer_code + ")"
class FluidCode(object):
def __init__(self):
......@@ -43,7 +49,8 @@ class FluidCode(object):
def add_layer(self, op, inputs, output, param_attr=None):
layer = Layer()
layer.op = op
layer.inputs = inputs
if inputs is not None:
layer.inputs = inputs
layer.output = output
if param_attr is not None:
layer.param_attr = param_attr
......
......@@ -99,7 +99,8 @@ class Graph(object):
idx = self.topo_sort.index(node_name)
del self.topo_sort[idx]
def print(self):
for i, tmp in enumerate(self.topo_sort):
print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs)
print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs,
self.node_map[tmp].outputs)
......@@ -14,7 +14,35 @@
from x2paddle.parser.tf_parser import TFGraph
from x2paddle.core.emitter import Emitter
from x2paddle.core.fluid_code import FluidCode
class TFEmitter(Emitter):
def __init__(self):
def __init__(self, parser):
super(TFEmitter, self).__init__()
self.parser = parser
self.graph = parser.tf_graph
self.fluid_code = FluidCode()
def run(self):
print("Total nodes: {}".format(len(self.graph.topo_sort)))
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
if hasattr(self, op):
emit_func = getattr(self, op)
emit_func(node)
def Placeholder(self, node):
shape = node.out_shapes[0]
dtype = node.dtype
attr = {
'dtype': '\{}\''.format(dtype),
'shape': shape,
'name': '\'{}\''.format(node.layer_name)
}
self.fluid_code.add_layer("data",
inputs=inputs,
output=node,
param_attr=attr)
print(self.fluid_code.layers[0].get_code())
......@@ -18,14 +18,21 @@ from x2paddle.parser.tf_parser import TFGraph
class TFGraphOptimizer(object):
def __init__(self):
print("Not Implement")
self.useless_op = [
'NoOp']
def remove_useless_node(self, graph):
for node_name, node in graph.node_map.items():
if node.layer_type in self.useless_op:
graph.remove_node(node_name)
self.identity_ops = ['Identity']
def remove_isolated_node(self, graph):
# delete isolated nodes
isolated_nodes = list()
for node_name in graph.node_map.keys():
if len(graph.get_node(node_name).inputs) == 0 or len(
graph.get_node(node_name).outputs) == 0:
isolated_nodes.append(node_name)
graph.remove_node(node_name)
def run(self, graph):
self.remove_isolated_node(graph)
# TODO identity node remove
......
......@@ -17,6 +17,7 @@ from tensorflow.python.platform import gfile
import tensorflow as tf
import copy
class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None):
if layer_name is None:
......@@ -25,6 +26,24 @@ class TFGraphNode(GraphNode):
super(TFGraphNode, self).__init__(layer, layer_name)
self.layer_type = layer.op
self.dtype_map = {1: "float32", 3: "int32", 9: "int64"}
@property
def out_shapes(self):
values = self.layer.attr["_output_shapes"].list.shape
out_shapes = list()
for value in values:
shape = [dim.size for dim in value.dim]
out_shapes.append(shape)
return out_shapes
@property
def dtype(self):
dtype = self.layer.attr["dtype"].type
if dtype not in self.dtype_map:
raise Exception("Dtype[{}] not in dtype_map".format(dtype))
return self.dtype_map[dtype]
class TFGraph(Graph):
def __init__(self, model):
......@@ -40,11 +59,13 @@ class TFGraph(Graph):
if in_node.strip().split(':')[0] in self.node_map:
self.connect(in_node.strip().split(':')[0], layer_name)
else:
raise Exception('input[{}] of node[{}] does not exist in node_map'.format(in_node, layer_name))
raise Exception(
'input[{}] of node[{}] does not exist in node_map'.
format(in_node, layer_name))
else:
self.connect(in_node, layer_name)
super(TFGraph, self).build()
super(TFGraph, self).build()
class TFParser(object):
......@@ -52,7 +73,8 @@ class TFParser(object):
assert in_nodes is not None, "in_nodes should not be None"
assert out_nodes is not None, "out_nodes should not be None"
assert in_shapes is not None, "in_shapes should not be None"
assert len(in_shapes) == len(in_nodes), "length of in_shapes and in_nodes should be equal"
assert len(in_shapes) == len(
in_nodes), "length of in_shapes and in_nodes should be equal"
sess = tf.Session()
with gfile.FastGFile(pb_model, 'rb') as f:
......@@ -60,7 +82,7 @@ class TFParser(object):
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
sess.run(tf.global_variables_initializer())
self.tf_graph = TFGraph(sess.graph._as_graph_def(add_shapes=True)[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册