提交 33166c4e 编写于 作者: J jiangjiajun

first commit

上级 df8bfe33
Warning: TensorFlow2Paddle is not stable yet
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from name_generator import NameGenerator
class GraphNode(object):
def __init__(self, layer):
self.in_edges = list()
self.out_edges = list()
self.layer = layer
self.ref_name = None
class Graph(object):
def __init__(self, model):
self.node_map = dict()
self.input_nodes = list()
self.output_nodes = list()
self.topological_sort = list()
self.model = model
self.name_generator = NameGenerator()
def build(self):
self._make_input_nodes()
self._make_output_nodes()
self._get_topological_sort()
self._gen_newname_for_nodes()
def _make_input_nodes(self):
for name, node in self.node_map.items():
node.left_in_edges = len(node.in_edges)
if len(node.in_edges) == 0:
self.input_nodes.append(name)
def _make_output_nodes(self):
for name, node in self.node_map.items():
if len(node.out_edges) == 0:
self.output_nodes.append(name)
def _get_topological_sort(self):
self.topological_sort = self.input_nodes[:]
idx = 0
while idx < len(self.topological_sort):
current_node = self.node_map[self.topological_sort[idx]]
for next_node in current_node.out_edges:
next_node_info = self.node_map[next_node]
next_node_info.left_in_edges -= 1
if next_node_info.left_in_edges == 0:
self.topological_sort.append(next_node)
idx += 1
def _gen_newname_for_nodes(self):
for node_name in self.topological_sort:
node = self.node_map[node_name]
ref_name = self.name_generator.get_name(node)
self.node_map[node.layer.name].ref_name = ref_name
def get_node(self, name):
if name not in self.node_map:
raise Exception("Graph doesn't have node [%s]." % name)
else:
return self.node_map[name]
def _make_connection(self, src, dst):
if src == dst or src not in self.node_map or dst not in self.node_map:
raise Exception('Warning: Node not exist or there is a self-loop')
if src not in self.node_map[dst].in_edges:
self.node_map[dst].in_edges.append(src)
if dst not in self.node_map[src].out_edges:
self.node_map[src].out_edges.append(dst)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class NameGenerator(object):
def __init__(self):
self.param_index = 0
self.input_index = 0
self.net_index = 0
self.const_index = 0
self.names = dict()
def get_name(self, node):
ref_name = None
op_name = node.type
if node.layer.name in self.names:
return self.names[node.layer.name]
if op_name == "variablev2":
ref_name = "param_" + str(self.param_index)
self.param_index += 1
elif op_name == "placeholder":
ref_name = "input_" + str(self.input_index)
self.input_index += 1
elif op_name == "const":
ref_name = "const_" + str(self.const_index)
self.const_index += 1
elif op_name.lower() == "identity":
ref_name = self.names[node.layer.input[0]]
else:
ref_name = "net_" + str(self.net_index)
self.net_index += 1
self.names[node.layer.name] = ref_name
return ref_name
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from graph import GraphNode, Graph
from tensorflow.core.framework import attr_value_pb2
class TensorflowGraphNode(GraphNode):
def __init__(self, layer):
super(TensorflowGraphNode, self).__init__(layer)
self.codes = list()
self.dataformat = 'NCHW'
@property
def type(self):
return self.layer.op.lower()
@property
def name(self):
return self.layer.name
# TODO
def get_attr(self, name, default_value=None):
if name in self.layer.attr:
attr = self.layer.attr[name]
field = attr.WhichOneof('value')
val = getattr(attr, field) if field else default_value
if isinstance(val, attr_value_pb2.AttrValue.ListValue):
return list(val.ListFields()[0][1])
else:
return val.decode('utf-8') if isinstance(val, bytes) else val
else:
return default_value
class TensorflowGraph(Graph):
def __init__(self, model):
super(TensorflowGraph, self).__init__(model)
self.model = model
def build(self):
for i, layer in enumerate(self.model.node):
self.node_map[layer.name] = TensorflowGraphNode(layer)
for pred in layer.input:
if pred not in self.node_map:
raise Exception('input: {} not in node_map'.format(pred))
self._make_connection(pred, layer.name)
super(TensorflowGraph, self).build()
self._check_dataformat()
# check the dataformat of network
def _check_dataformat(self):
ss = list()
for i in range(0, len(self.topological_sort)):
current_node = self.node_map[self.topological_sort[i]]
if current_node.type == 'conv2d':
s = current_node.layer.attr['data_format'].s
if s != 'NHWC' and s != 'NCHW':
raise Exception('Unkown dataformat {}'.format(s))
ss.append(s)
if len(set(ss)) > 1:
raise Exception("Two type of dataformat exist in this model")
if len(set(ss)) == 0:
return
for i in range(0, len(self.topological_sort)):
current_node = self.node_map[self.topological_sort[i]]
current_node.dataformat = ss[0]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow
from tensorflow_graph import TensorflowGraph
class TensorflowParser(object):
def __init__(self,
meta_file,
checkpoint_file,
dest_nodes,
input_shape=None,
in_nodes=None):
graph_def = None
self.weights = dict()
with tensorflow.Session() as sess:
if meta_file is None:
raise Exception("meta_file must be provided")
new_saver = tensorflow.train.import_meta_graph(meta_file)
if checkpoint_file is not None:
new_saver.restore(
sess, tensorflow.train.latest_checkpoint(checkpoint_file))
for var in tensorflow.global_variables():
value = var.eval(sess)
self.weights[var.name.split(':')[0]] = value
graph_def, ver = tensorflow.get_default_graph()._as_graph_def(
add_shapes=True)
if in_nodes is not None and input_shape is not None:
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
graph_def = strip_unused_lib.strip_unused(
input_graph_def=graph_def,
input_node_names=in_nodes,
output_node_names=dest_nodes,
placeholder_type_enum=dtypes.float32.as_datatype_enum)
input_list = [None]
for i in range(len(input_shape)):
input_list.append(tensorflow.Dimension(input_shape[i]))
tensor_input = tensorflow.TensorShape(input_list)
self.tf_graph = TensorflowGraph(graph_def)
for node in self.tf_graph.model.node:
if node.name in in_nodes:
node.attr['shape'].list.shape.extend(
[tensor_input.as_proto()])
node.attr['_output_shapes'].list.shape.pop()
node.attr['_output_shapes'].list.shape.extend(
[tensor_input.as_proto()])
else:
raise Exception('in_nodes and output_nodes need be provided')
self.tf_graph.build()
from paddle_emitter import PaddleEmitter
from tensorflow_parser import TensorflowParser
class Transformer(object):
def __init__(self, meta_file, ckpt_file, out_nodes, in_shape, in_nodes):
self.parser = TensorflowParser(meta_file, ckpt_file, out_nodes,
in_shape, in_nodes)
self.emitter = PaddleEmitter(self.parser.tf_graph)
def transform_code(self, out_py_file):
filew = open(out_py_file, 'w')
codes = self.emitter.gen_code()
filew.write(codes)
filew.close()
def transform_weight(self, out_dir):
self.emitter.gen_weight(self.parser.weights, out_dir)
def run(self, dst_dir):
import os
if os.path.isdir(dst_dir) or os.path.isfile(dst_dir):
print("{} already exists, set a new directory")
return
if not os.path.isdir(dst_dir):
os.mkdir(dst_dir)
self.transform_code(dst_dir + "/mymodel.py")
if (len(self.parser.weights) == 0):
print("There is no tensorflow model weight translate to paddle")
else:
self.transform_weight(dst_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册