提交 b444f36d 编写于 作者: J jiangjiajun

new structure

上级 365d13f9
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Layer(object):
def __init__(self):
self.op = None
self.param_attr = dict()
self.input = None
self.output = None
self.str_code = None
def get_code(self):
if self.str_code is not None:
return self.str_code
class FluidCode(object):
def __init__(self):
self.codes = list()
def add_layer(self, op, input, output, param_attr=None):
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils import *
import collections
class GraphNode(object):
def __init__(self, layer, layer_name=None):
self.inputs = list()
self.outputs = list()
self.layer = layer
if layer_name is not None:
self.layer_name = layer_name
else:
self.layer_name = layer.name
def __hash__(self):
return hash(self.layer.name)
def __eq__(self, other):
if self.layer.name == other.layer.name:
return True
return False
class Graph(object):
def __init__(self, model):
self.node_map = collections.OrderedDict()
self.input_nodes = list()
self.output_nodes = list()
self.topo_sort = list()
self.model = model
def build(self, input_format):
self._make_input_nodes()
self._make_output_nodes()
self._get_topo_sort()
def _make_input_nodes(self):
for name, node in self.node_map.items():
if len(node.inputs) == 0:
self.input_nodes.append(name)
def _make_output_nodes(self):
for name, node in self.node_map.items():
if len(node.outputs) == 0:
self.output_nodes.append(name)
def _get_topo_sort(self):
num_inputs = dict()
for name, node in self.node_map.items():
num_inputs[name] = len(node.inputs)
self.topo_sort = self.input_nodes[:]
while idx in range(len(self.topo_sort)):
current_node = self.node_map[self.topo_sort[idx]]
for node in current_node.outputs:
num_inputs[node.layer_name] -= 1
if num_inputs[node.layer_name] == 0:
self.topo_sort.append(node.layer_name)
def get_node(self, name):
if name not in self.node_map:
raise Exception("Graph doesn't have node [%s]." % name)
else:
return self.node_map[name]
def connect(self, src, dst):
if src.layer_name == dst.layer_name or src.layer_name not in \
self.node_map or dst.layer_name not in self.node_map:
raise Exception('Warning: Node not exist or there is a self-loop')
self.node_map[dst.layer_name].inputs.append(src)
self.node_map[src.layer_name].outputs.append(dst)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.graph import GraphNode, Graph
class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None):
super(TFGraphNode, self).__init__(layer, layer_name)
self.layer_type = layer.op.lower()
class TFGraph(Graph):
def __init__(self, model):
super(TFGraph, self).__init__(model)
class TFParser(object):
def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None):
assert in_nodes is not None, "in_nodes should not be None"
assert out_nodes is not None, "out_nodes should not be None"
assert in_shapes is not None, "in_shapes should not be None"
assert len(in_shapes) == len(in_nodes), "length of in_shapes and in_nodes should be equal"
serialized_str = open(pb_model, 'rb').read()
tf.reset_default_graph()
graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized_str)
sess = tf.Session(graph=tf.get_default_graph())
sess.run(tf.global_variables_initializer())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册