提交 7a96c492 编写于 作者: J jiangjiajun

push tensorflow2fluid

上级 c4136505
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from name_generator import NameGenerator
class GraphNode(object):
def __init__(self, layer):
self.inputs = list()
self.outputs = list()
self.layer = layer
self.ref_name = None
self.output_name = None
def __hash__(self):
return hash(self.layer.name)
def __eq__(self, other):
if self.layer.name == other.layer.name:
return True
return False
class Graph(object):
def __init__(self, model):
self.node_map = dict()
self.input_nodes = list()
self.output_nodes = list()
self.topological_sort = list()
self.model = model
self.name_generator = NameGenerator()
def build(self):
self._make_input_nodes()
self._make_output_nodes()
self._get_topological_sort()
self._gen_newname_for_nodes()
def _make_input_nodes(self):
for name, node in self.node_map.items():
if len(node.outputs) == 0 and len(node.inputs) == 0:
continue
node.left_inputs = len(node.inputs)
if len(node.inputs) == 0:
self.input_nodes.append(name)
def _make_output_nodes(self):
for name, node in self.node_map.items():
if len(node.outputs) == 0 and len(node.inputs) == 0:
continue
if len(node.outputs) == 0:
self.output_nodes.append(name)
def _get_topological_sort(self):
self.topological_sort = self.input_nodes[:]
idx = 0
while idx < len(self.topological_sort):
current_node = self.node_map[self.topological_sort[idx]]
for next_node in current_node.outputs:
next_node_info = self.node_map[next_node.layer_name]
next_node_info.left_inputs -= 1
if next_node_info.left_inputs == 0:
self.topological_sort.append(next_node.layer_name)
idx += 1
def _gen_newname_for_nodes(self):
for node_name in self.topological_sort:
node = self.node_map[node_name]
ref_name = self.name_generator.get_name(node)
self.node_map[node.layer.name].ref_name = ref_name
self.node_map[node.layer.name].output_name = ref_name.split('[')[0]
def get_node(self, name):
if name not in self.node_map:
raise Exception("Graph doesn't have node [%s]." % name)
else:
return self.node_map[name]
def _make_connection(self, src, dst):
if src.layer_name == dst.layer_name or src.layer_name not in self.node_map or dst.layer_name not in self.node_map:
raise Exception('Warning: Node not exist or there is a self-loop')
if src not in self.node_map[dst.layer_name].inputs:
self.node_map[dst.layer_name].inputs.append(src)
if dst not in self.node_map[src.layer_name].outputs:
self.node_map[src.layer_name].outputs.append(dst)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class NameGenerator(object):
def __init__(self):
self.param_index = 0
self.input_index = 0
self.net_index = 0
self.const_index = 0
self.names = dict()
def get_name(self, node):
ref_name = None
op_name = node.layer_type
if node.layer.name in self.names:
return self.names[node.layer.name]
if op_name == "variablev2":
ref_name = "param_" + str(self.param_index)
self.param_index += 1
elif op_name == "placeholder":
ref_name = "input_" + str(self.input_index)
self.input_index += 1
elif op_name == "const":
ref_name = "const_" + str(self.const_index)
self.const_index += 1
elif op_name.lower() == "identity":
ref_name = self.names[node.layer.input[0]]
else:
ref_name = "net_" + str(self.net_index)
self.net_index += 1
self.names[node.layer.name] = ref_name
return ref_name
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from graph import GraphNode, Graph
from tensorflow.core.framework import attr_value_pb2
class TensorflowGraphNode(GraphNode):
def __init__(self, layer):
super(TensorflowGraphNode, self).__init__(layer)
self.codes = list()
self.data_format = 'NCHW'
@property
def layer_type(self):
return self.layer.op.lower()
@property
def layer_name(self):
return self.layer.name
def get_attr(self, name, default_value=None):
if name in self.layer.attr:
attr = self.layer.attr[name]
field = attr.WhichOneof('value')
val = getattr(attr, field) if field else default_value
if isinstance(val, attr_value_pb2.AttrValue.ListValue):
return list(val.ListFields()[0][1])
else:
return val.decode('utf-8') if isinstance(val, bytes) else val
else:
return default_value
class TensorflowGraph(Graph):
def __init__(self, tf_graph):
super(TensorflowGraph, self).__init__(tf_graph)
self.tf_graph = tf_graph
def build(self):
skip_node = set(['const'])
for i, layer in enumerate(self.tf_graph.node):
self.node_map[layer.name] = TensorflowGraphNode(layer)
for i, layer in enumerate(self.tf_graph.node):
if layer.op.lower() in skip_node:
continue
for pred in layer.input:
if pred not in self.node_map and pred.split(
':')[0] in self.node_map:
node = self.node_map[pred.split(':')[0]]
if node.layer_type == "switch":
self._make_connection(node, self.node_map[layer.name])
else:
raise Exception("Need to fix here")
elif pred in self.node_map:
self._make_connection(self.node_map[pred],
self.node_map[layer.name])
else:
raise Exception("input: {} not in node_map".format(pred))
super(TensorflowGraph, self).build()
self._remove_useless_nodes()
self._check_dataformat()
def _check_dataformat(self):
ss = list()
for i in range(0, len(self.topological_sort)):
current_node = self.node_map[self.topological_sort[i]]
if 'data_format' in current_node.layer.attr:
s = current_node.layer.attr['data_format'].s
if s != 'NHWC' and s != 'NCHW':
raise Exception('Unkown dataformat {}'.format(s))
ss.append(s)
if len(set(ss)) > 1:
raise Exception("Two type of dataformat exist in this model")
if len(set(ss)) == 0:
return
for k, v in self.node_map.items():
self.node_map[k].data_format = ss[0]
def _remove_useless_nodes(self):
useless_type = set(
['identity', 'placeholderwithdefault', 'switch', 'merge'])
remove_index = list()
for i in range(0, len(self.topological_sort)):
name = self.topological_sort[i]
current_node = self.node_map[name]
if current_node.layer_type in useless_type:
input = current_node.inputs[0]
for node in current_node.outputs:
for k in range(0, len(node.inputs)):
if node.inputs[k] == current_node:
node.inputs[k] = input
if node not in input.outputs:
input.outputs.append(node)
input.outputs.remove(current_node)
del self.node_map[name]
if name in self.output_nodes:
self.output_nodes.remove(name)
if name in self.input_nodes:
self.input_nodes.remove(name)
remove_index.append(i)
remove_index.sort(reverse=True)
for i in range(0, len(remove_index)):
del self.topological_sort[remove_index[i]]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow
from tensorflow_graph import TensorflowGraph
from tensorflow.python.framework import tensor_util
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
class TensorflowCkptParser(object):
def __init__(self,
meta_file,
checkpoint_file,
dest_nodes,
input_shape=None,
in_nodes=None):
graph_def = None
self.weights = None
with tensorflow.Session() as sess:
if meta_file is None:
raise Exception("meta_file must be provided")
new_saver = tensorflow.train.import_meta_graph(meta_file)
if checkpoint_file is not None:
self.weights = dict()
new_saver.restore(
sess, tensorflow.train.latest_checkpoint(checkpoint_file))
for var in tensorflow.global_variables():
value = var.eval(sess)
self.weights[var.name.split(':')[0]] = value
graph_def, ver = tensorflow.get_default_graph()._as_graph_def(
add_shapes=True)
if in_nodes is not None and input_shape is not None:
graph_def = strip_unused_lib.strip_unused(
input_graph_def=graph_def,
input_node_names=in_nodes,
output_node_names=dest_nodes,
placeholder_type_enum=dtypes.float32.as_datatype_enum)
self.tf_graph = TensorflowGraph(graph_def)
else:
raise Exception('in_nodes and output_nodes need be provided')
self.tf_graph.build()
class TensorflowPbParser(object):
def __init__(self, pb_file, dest_nodes, input_shape=None, in_nodes=None):
with open(pb_file) as f:
serialized = f.read()
tensorflow.reset_default_graph()
original_graph_def = tensorflow.GraphDef()
original_graph_def.ParseFromString(serialized)
original_graph_def = strip_unused_lib.strip_unused(
input_graph_def=original_graph_def,
input_node_names=in_nodes,
output_node_names=dest_nodes,
placeholder_type_enum=dtypes.float32.as_datatype_enum)
graph_def = tensorflow.GraphDef()
graph_def.ParseFromString(original_graph_def.SerializeToString())
in_type_list = dict()
for node in graph_def.node:
if node.name in in_nodes:
in_type_list[node.name] = node.attr['dtype'].type
input_shape = list(input_shape)
if not isinstance(input_shape[0], list):
input_shape = [input_shape]
input_map = dict()
for i in range(len(input_shape)):
if in_type_list[in_nodes[i]] == 1 or in_type_list[
in_nodes[i]] == 0:
dtype = tensorflow.float32
x = tensorflow.placeholder(dtype, shape=input_shape[i])
elif in_type_list[in_nodes[i]] == 3:
dtype = tensorflow.int32
x = tensorflow.placehoder(dtype, shape=input_shape[i])
else:
raise Exception(
"Unexpected dtype for input, only support float32 and int32 now"
)
input_map[in_nodes[i] + ":0"] = x
tensorflow.import_graph_def(graph_def, name="", input_map=input_map)
graph_def = tensorflow.get_default_graph()._as_graph_def(
add_shapes=True)[0]
node = graph_def.node[0]
self.tf_graph = TensorflowGraph(graph_def)
self.tf_graph.build()
self.weights = dict()
for node in graph_def.node:
if node.op.lower() == "const":
try:
node.attr['value'].tensor.tensor_content
weight = tensor_util.MakeNdarray(node.attr['value'].tensor)
self.weights[node.name] = weight
except:
continue
from paddle_emitter import PaddleEmitter
from tensorflow_parser import TensorflowCkptParser
from tensorflow_parser import TensorflowPbParser
class Transformer(object):
def __init__(self, meta_file, ckpt_file, out_nodes, in_shape, in_nodes, save_dir):
self.parser = TensorflowCkptParser(meta_file, ckpt_file, out_nodes,
in_shape, in_nodes)
self.emitter = PaddleEmitter(self.parser, save_dir)
def transform_code(self):
codes = self.emitter.run()
def run(self):
self.transform_code()
class PbTransformer(object):
def __init__(self, pb_file, out_nodes, in_shape, in_nodes, save_dir):
self.parser = TensorflowPbParser(pb_file, out_nodes, in_shape, in_nodes)
self.emitter = PaddleEmitter(self.parser, save_dir)
node = self.parser.tf_graph.tf_graph.node[0]
def transform_code(self):
codes = self.emitter.run()
def run(self):
self.transform_code()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册