提交 ec632e66 编写于 作者: J jiangjiajun

test code

上级 b444f36d
...@@ -11,3 +11,8 @@ ...@@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.parser.tf_parser import TFParser
parser = TFParser('/ssd2/Jason/github/X2Paddle/x2paddle/tests/frozen_darknet_yolov3_model.pb',
in_nodes=['inputs'], out_nodes=['output_boxes'],
in_shapes=[[-1, 416, 416, 3]])
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from utils import *
import collections import collections
...@@ -44,7 +43,7 @@ class Graph(object): ...@@ -44,7 +43,7 @@ class Graph(object):
self.topo_sort = list() self.topo_sort = list()
self.model = model self.model = model
def build(self, input_format): def build(self):
self._make_input_nodes() self._make_input_nodes()
self._make_output_nodes() self._make_output_nodes()
self._get_topo_sort() self._get_topo_sort()
...@@ -65,7 +64,7 @@ class Graph(object): ...@@ -65,7 +64,7 @@ class Graph(object):
num_inputs[name] = len(node.inputs) num_inputs[name] = len(node.inputs)
self.topo_sort = self.input_nodes[:] self.topo_sort = self.input_nodes[:]
while idx in range(len(self.topo_sort)): for idx in range(len(self.topo_sort)):
current_node = self.node_map[self.topo_sort[idx]] current_node = self.node_map[self.topo_sort[idx]]
for node in current_node.outputs: for node in current_node.outputs:
num_inputs[node.layer_name] -= 1 num_inputs[node.layer_name] -= 1
...@@ -79,8 +78,6 @@ class Graph(object): ...@@ -79,8 +78,6 @@ class Graph(object):
return self.node_map[name] return self.node_map[name]
def connect(self, src, dst): def connect(self, src, dst):
if src.layer_name == dst.layer_name or src.layer_name not in \ if dst not in self.node_map:
self.node_map or dst.layer_name not in self.node_map: raise Exception("node[{}] not in graph".format(dst))
raise Exception('Warning: Node not exist or there is a self-loop') self.node_map[dst].inputs.append(src)
self.node_map[dst.layer_name].inputs.append(src)
self.node_map[src.layer_name].outputs.append(dst)
...@@ -13,18 +13,40 @@ ...@@ -13,18 +13,40 @@
# limitations under the License. # limitations under the License.
from x2paddle.core.graph import GraphNode, Graph from x2paddle.core.graph import GraphNode, Graph
from tensorflow.python.platform import gfile
import tensorflow as tf
import copy
class TFGraphNode(GraphNode): class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None): def __init__(self, layer, layer_name=None):
super(TFGraphNode, self).__init__(layer, layer_name) super(TFGraphNode, self).__init__(layer, layer_name)
self.layer_type = layer.op.lower() self.layer_type = layer.op
class TFGraph(Graph): class TFGraph(Graph):
def __init__(self, model): def __init__(self, model):
super(TFGraph, self).__init__(model) super(TFGraph, self).__init__(model)
self.multi_output_ops = [
'Split',
'Unpack']
def build(self):
for layer in self.model.node:
self.node_map[layer.name] = TFGraphNode(layer)
for layer_name, node in self.node_map.items():
for in_node in node.layer.input:
if in_node not in self.node_map:
if in_node.strip().split(':')[0] in self.node_map:
self.connect(in_node, layer_name)
else:
raise Exception('input[{}] of node[{}] does not exist in node_map'.format(in_node, layer_name))
else:
if self.node_map[in_node].layer_type in self.multi_output_ops:
in_node += ":0"
self.connect(in_node, layer_name)
super(TFGraph, self).build()
class TFParser(object): class TFParser(object):
def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None): def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None):
...@@ -33,11 +55,14 @@ class TFParser(object): ...@@ -33,11 +55,14 @@ class TFParser(object):
assert in_shapes is not None, "in_shapes should not be None" assert in_shapes is not None, "in_shapes should not be None"
assert len(in_shapes) == len(in_nodes), "length of in_shapes and in_nodes should be equal" assert len(in_shapes) == len(in_nodes), "length of in_shapes and in_nodes should be equal"
serialized_str = open(pb_model, 'rb').read() sess = tf.Session()
tf.reset_default_graph() with gfile.FastGFile(pb_model, 'rb') as f:
graph_def = tf.GraphDef() graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized_str) graph_def.ParseFromString(f.read())
sess.graph.as_default()
sess = tf.Session(graph=tf.get_default_graph()) tf.import_graph_def(graph_def, name='')
sess.run(tf.global_variables_initializer())
sess.run(tf.global_variables_initializer())
self.tf_graph = TFGraph(sess.graph._as_graph_def(add_shapes=True)[0])
self.tf_graph.build()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册