diff --git a/x2paddle/convert.py b/x2paddle/convert.py index a028662d70d6814266a52ccc6826e640c7671b7e..f897dc9bda711a3936c055f5a3822320f7f21674 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -11,3 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from x2paddle.parser.tf_parser import TFParser + +parser = TFParser('/ssd2/Jason/github/X2Paddle/x2paddle/tests/frozen_darknet_yolov3_model.pb', + in_nodes=['inputs'], out_nodes=['output_boxes'], + in_shapes=[[-1, 416, 416, 3]]) diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index 9ca89e64379676b19fe0cdcf98da3f4f3f781fad..1ec62280d6fac87deca91a2eacaaa8bd7a866cac 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from utils import * import collections @@ -44,7 +43,7 @@ class Graph(object): self.topo_sort = list() self.model = model - def build(self, input_format): + def build(self): self._make_input_nodes() self._make_output_nodes() self._get_topo_sort() @@ -65,7 +64,7 @@ class Graph(object): num_inputs[name] = len(node.inputs) self.topo_sort = self.input_nodes[:] - while idx in range(len(self.topo_sort)): + for idx in range(len(self.topo_sort)): current_node = self.node_map[self.topo_sort[idx]] for node in current_node.outputs: num_inputs[node.layer_name] -= 1 @@ -79,8 +78,6 @@ class Graph(object): return self.node_map[name] def connect(self, src, dst): - if src.layer_name == dst.layer_name or src.layer_name not in \ - self.node_map or dst.layer_name not in self.node_map: - raise Exception('Warning: Node not exist or there is a self-loop') - self.node_map[dst.layer_name].inputs.append(src) - self.node_map[src.layer_name].outputs.append(dst) + if dst not in self.node_map: + raise Exception("node[{}] not in graph".format(dst)) + self.node_map[dst].inputs.append(src) diff --git a/x2paddle/parser/tf_parser.py b/x2paddle/parser/tf_parser.py index af880481c292f8a5099e12630d1af5a737e682ef..21d62a6aef665c4867e9e069b3540485bfe02614 100644 --- a/x2paddle/parser/tf_parser.py +++ b/x2paddle/parser/tf_parser.py @@ -13,18 +13,40 @@ # limitations under the License. from x2paddle.core.graph import GraphNode, Graph - +from tensorflow.python.platform import gfile +import tensorflow as tf +import copy class TFGraphNode(GraphNode): def __init__(self, layer, layer_name=None): super(TFGraphNode, self).__init__(layer, layer_name) - self.layer_type = layer.op.lower() + self.layer_type = layer.op class TFGraph(Graph): def __init__(self, model): super(TFGraph, self).__init__(model) + self.multi_output_ops = [ + 'Split', + 'Unpack'] + + def build(self): + for layer in self.model.node: + self.node_map[layer.name] = TFGraphNode(layer) + for layer_name, node in self.node_map.items(): + for in_node in node.layer.input: + if in_node not in self.node_map: + if in_node.strip().split(':')[0] in self.node_map: + self.connect(in_node, layer_name) + else: + raise Exception('input[{}] of node[{}] does not exist in node_map'.format(in_node, layer_name)) + else: + if self.node_map[in_node].layer_type in self.multi_output_ops: + in_node += ":0" + self.connect(in_node, layer_name) + + super(TFGraph, self).build() class TFParser(object): def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None): @@ -33,11 +55,14 @@ class TFParser(object): assert in_shapes is not None, "in_shapes should not be None" assert len(in_shapes) == len(in_nodes), "length of in_shapes and in_nodes should be equal" - serialized_str = open(pb_model, 'rb').read() - tf.reset_default_graph() - graph_def = tf.GraphDef() - graph_def.ParseFromString(serialized_str) - - sess = tf.Session(graph=tf.get_default_graph()) - sess.run(tf.global_variables_initializer()) + sess = tf.Session() + with gfile.FastGFile(pb_model, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + sess.graph.as_default() + tf.import_graph_def(graph_def, name='') + sess.run(tf.global_variables_initializer()) + + self.tf_graph = TFGraph(sess.graph._as_graph_def(add_shapes=True)[0]) + self.tf_graph.build()