提交 c5b22aca 编写于 作者: J jiangjiajun

change src dir to tf2fluid

上级 e4babb02
...@@ -18,33 +18,74 @@ from tensorflow_parser import TensorflowPbParser ...@@ -18,33 +18,74 @@ from tensorflow_parser import TensorflowPbParser
from six import text_type as _text_type from six import text_type as _text_type
from utils import * from utils import *
import argparse import argparse
import logging import logging
import os import os
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
def _get_parser(): def _get_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--meta_file", "-m", type=_text_type, parser.add_argument(
default=None, help="meta file path for checkpoint format") "--meta_file",
parser.add_argument("--ckpt_dir", "-c", type=_text_type, "-m",
default=None, help="checkpoint directory") type=_text_type,
parser.add_argument("--pb_file", "-p", type=_text_type, default=None,
default=None, help="pb model file path") help="meta file path for checkpoint format")
parser.add_argument("--in_nodes", "-i", type=_text_type, nargs="+", parser.add_argument(
default=None, help="input nodes name") "--ckpt_dir",
parser.add_argument("--input_shape", "-is", type=_text_type, nargs="+", "-c",
default=None, help="input tensor shape") type=_text_type,
parser.add_argument("--output_nodes", "-o", type=_text_type, nargs="+", default=None,
default=None, help="output nodes name") help="checkpoint directory")
parser.add_argument("--save_dir", "-s", type=_text_type, parser.add_argument(
default=None, help="path to save transformed paddle model") "--pb_file",
parser.add_argument("--input_format", "-sf", type=_text_type, "-p",
default=None, help="input data format(NHWC/NCHW or OTHER)") type=_text_type,
parser.add_argument("--use_cuda", "-u", type=_text_type, default=None,
default="True", help="True for use gpu") help="pb model file path")
parser.add_argument(
"--in_nodes",
"-i",
type=_text_type,
nargs="+",
default=None,
help="input nodes name")
parser.add_argument(
"--input_shape",
"-is",
type=_text_type,
nargs="+",
default=None,
help="input tensor shape")
parser.add_argument(
"--output_nodes",
"-o",
type=_text_type,
nargs="+",
default=None,
help="output nodes name")
parser.add_argument(
"--save_dir",
"-s",
type=_text_type,
default=None,
help="path to save transformed paddle model")
parser.add_argument(
"--input_format",
"-sf",
type=_text_type,
default=None,
help="input data format(NHWC/NCHW or OTHER)")
parser.add_argument(
"--use_cuda",
"-u",
type=_text_type,
default="True",
help="True for use gpu")
return parser return parser
def _convert(args):
def run(args):
if args.meta_file is None and args.pb_file is None: if args.meta_file is None and args.pb_file is None:
raise Exception("Need to define --meta_file or --pb_file") raise Exception("Need to define --meta_file or --pb_file")
if args.input_format is None: if args.input_format is None:
...@@ -78,27 +119,30 @@ def _convert(args): ...@@ -78,27 +119,30 @@ def _convert(args):
items[i] = int(items[i]) items[i] = int(items[i])
else: else:
items[i] = None items[i] = None
input_shape.append(items) input_shape.append(items)
logging.info("Loading tensorflow model...") logging.info("Loading tensorflow model...")
if args.meta_file is not None: if args.meta_file is not None:
parser = TensorflowCkptParser(args.meta_file, args.ckpt_dir, parser = TensorflowCkptParser(args.meta_file, args.ckpt_dir,
args.output_nodes, input_shape, args.in_nodes, input_format) args.output_nodes, input_shape,
args.in_nodes, input_format)
else: else:
parser = TensorflowPbParser(args.pb_file, args.output_nodes, parser = TensorflowPbParser(args.pb_file, args.output_nodes,
input_shape, args.in_nodes, input_format) input_shape, args.in_nodes, input_format)
logging.info("Tensorflow model loaded!") logging.info("Tensorflow model loaded!")
emitter = PaddleEmitter(parser, args.save_dir) emitter = PaddleEmitter(parser, args.save_dir)
emitter.run() emitter.run()
open(args.save_dir+"/__init__.py", "w").close() open(args.save_dir + "/__init__.py", "w").close()
def _main(): def _main():
parser = _get_parser() parser = _get_parser()
args = parser.parse_args() args = parser.parse_args()
_convert(args) run(args)
if __name__ == "__main__": if __name__ == "__main__":
_main() _main()
...@@ -12,35 +12,41 @@ ...@@ -12,35 +12,41 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle.fluid as fluid
import sys
class NameGenerator(object):
def __init__(self):
self.param_index = 0
self.input_index = 0
self.net_index = 0
self.const_index = 0
self.names = dict()
def get_name(self, node): class ModelLoader(object):
ref_name = None def __init__(self, model_dir, use_cuda=False):
op_name = node.layer_type sys.path.append(model_dir)
mymodel = __import__("mymodel")
self.model = mymodel.Model()
self.model.build()
self.inputs = self.model.inputs
self.outputs = self.model.outputs
if use_cuda:
self.exe = fluid.Executor(fluid.CUDAPlace(0))
else:
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe.run(fluid.default_startup_program())
if node.layer.name in self.names: var_list = list()
return self.names[node.layer.name] global_block = fluid.default_main_program().global_block()
with open(model_dir + "/save_var.list") as f:
for line in f:
try:
var = global_block.var(line.strip())
var_list.append(var)
except:
pass
fluid.io.load_vars(self.exe, model_dir, vars=var_list)
self.program = fluid.default_main_program()
if op_name == "variablev2": def save_inference_model(self, save_dir):
ref_name = "param_" + str(self.param_index) fluid.io.save_inference_model(save_dir, self.model.inputs,
self.param_index += 1 self.model.outputs, self.exe)
elif op_name == "placeholder":
ref_name = "input_" + str(self.input_index) def inference(self, feed_dict):
self.input_index += 1 result = self.exe.run(
elif op_name == "const": self.program, feed=feed_dict, fetch_list=self.model.outputs)
ref_name = "const_" + str(self.const_index) return result
self.const_index += 1
elif op_name.lower() == "identity":
ref_name = self.names[node.layer.input[0]]
else:
ref_name = "net_" + str(self.net_index)
self.net_index += 1
self.names[node.layer.name] = ref_name
return ref_name
...@@ -16,6 +16,7 @@ from graph import GraphNode, Graph ...@@ -16,6 +16,7 @@ from graph import GraphNode, Graph
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
from utils import * from utils import *
class TensorflowGraphNode(GraphNode): class TensorflowGraphNode(GraphNode):
dtype_map = {1: "float32", 3: "int32", 9: "int64"} dtype_map = {1: "float32", 3: "int32", 9: "int64"}
...@@ -64,18 +65,13 @@ class TensorflowGraphNode(GraphNode): ...@@ -64,18 +65,13 @@ class TensorflowGraphNode(GraphNode):
return val if isinstance(val, bytes) else val return val if isinstance(val, bytes) else val
else: else:
return default_value return default_value
def clear_code(self): def clear_code(self):
self.code.clear() self.code.clear()
class TensorflowGraph(Graph): class TensorflowGraph(Graph):
useless_type = [ useless_type = ['identity', 'placeholderwithdefault', 'switch', 'merge']
'identity',
'placeholderwithdefault',
'switch',
'merge'
]
def __init__(self, tf_graph): def __init__(self, tf_graph):
super(TensorflowGraph, self).__init__(tf_graph) super(TensorflowGraph, self).__init__(tf_graph)
...@@ -84,7 +80,8 @@ class TensorflowGraph(Graph): ...@@ -84,7 +80,8 @@ class TensorflowGraph(Graph):
def build(self, input_format): def build(self, input_format):
skip_node = set(['const']) skip_node = set(['const'])
for i, layer in enumerate(self.tf_graph.node): for i, layer in enumerate(self.tf_graph.node):
self.node_map[layer.name] = TensorflowGraphNode(layer, input_format) self.node_map[layer.name] = TensorflowGraphNode(
layer, input_format)
for i, layer in enumerate(self.tf_graph.node): for i, layer in enumerate(self.tf_graph.node):
if layer.op.lower() in skip_node: if layer.op.lower() in skip_node:
...@@ -94,22 +91,22 @@ class TensorflowGraph(Graph): ...@@ -94,22 +91,22 @@ class TensorflowGraph(Graph):
':')[0] in self.node_map: ':')[0] in self.node_map:
pred_node = self.node_map[pred.split(':')[0]] pred_node = self.node_map[pred.split(':')[0]]
if pred_node.layer_type == "switch": if pred_node.layer_type == "switch":
self._make_connection(pred_node, self._make_connection(pred_node,
self.node_map[layer.name]) self.node_map[layer.name])
elif pred_node.layer_type == "split" or \ elif pred_node.layer_type == "split" or \
pred_node.layer_type == "splitv": pred_node.layer_type == "splitv":
self.node_map[pred] = TensorflowGraphNode( self.node_map[pred] = TensorflowGraphNode(
pred_node.layer, input_format, pred) pred_node.layer, input_format, pred)
self._make_connection(self.node_map[pred], self._make_connection(self.node_map[pred],
self.node_map[layer.name]) self.node_map[layer.name])
self._make_connection(pred_node, self.node_map[pred]) self._make_connection(pred_node, self.node_map[pred])
else: else:
raise Exception("Unsupported situation(name:[{}], \ raise Exception("Unsupported situation(name:[{}], \
OP[{}])".format(node.layer_name, node.layer_type)) OP[{}])".format(node.layer_name, node.layer_type))
elif pred in self.node_map: elif pred in self.node_map:
self._make_connection(self.node_map[pred], self._make_connection(self.node_map[pred],
self.node_map[layer.name]) self.node_map[layer.name])
else: else:
raise Exception("input: {} not in node_map".format(pred)) raise Exception("input: {} not in node_map".format(pred))
......
...@@ -19,6 +19,7 @@ from tensorflow.python.tools import strip_unused_lib ...@@ -19,6 +19,7 @@ from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
import numpy import numpy
class TensorflowCkptParser(object): class TensorflowCkptParser(object):
def __init__(self, def __init__(self,
meta_file, meta_file,
...@@ -29,21 +30,22 @@ class TensorflowCkptParser(object): ...@@ -29,21 +30,22 @@ class TensorflowCkptParser(object):
input_format="NCHW".encode()): input_format="NCHW".encode()):
graph_def = None graph_def = None
self.weights = None self.weights = None
self.inputs = in_nodes
self.outputs = dest_nodes
sess = tf.Session() sess = tf.Session()
if meta_file is None: if meta_file is None:
raise Exception("meta_file must be provided") raise Exception("meta_file must be provided")
new_saver = tf.train.import_meta_graph(meta_file) new_saver = tf.train.import_meta_graph(meta_file)
if checkpoint_file is not None: if checkpoint_file is not None:
self.weights = dict() self.weights = dict()
new_saver.restore( new_saver.restore(sess,
sess, tf.train.latest_checkpoint(checkpoint_file)) tf.train.latest_checkpoint(checkpoint_file))
for var in tf.global_variables(): for var in tf.global_variables():
value = var.eval(sess) value = var.eval(sess)
self.weights[var.name.split(':')[0]] = value self.weights[var.name.split(':')[0]] = value
self.infer = ModelInfer(sess) self.infer = ModelInfer(sess)
graph_def, ver = tf.get_default_graph()._as_graph_def( graph_def, ver = tf.get_default_graph()._as_graph_def(add_shapes=True)
add_shapes=True)
if in_nodes is not None and input_shape is not None: if in_nodes is not None and input_shape is not None:
graph_def = strip_unused_lib.strip_unused( graph_def = strip_unused_lib.strip_unused(
...@@ -58,7 +60,8 @@ class TensorflowCkptParser(object): ...@@ -58,7 +60,8 @@ class TensorflowCkptParser(object):
shape = [tf.Dimension(x) for x in input_shape[index]] shape = [tf.Dimension(x) for x in input_shape[index]]
shape_proto = tf.TensorShape(shape).as_proto() shape_proto = tf.TensorShape(shape).as_proto()
node.attr['_output_shapes'].list.shape.pop() node.attr['_output_shapes'].list.shape.pop()
node.attr['_output_shapes'].list.shape.extend([shape_proto]) node.attr['_output_shapes'].list.shape.extend(
[shape_proto])
self.infer.gen_sample_data(node.name, input_shape[index]) self.infer.gen_sample_data(node.name, input_shape[index])
self.tf_graph = TensorflowGraph(graph_def) self.tf_graph = TensorflowGraph(graph_def)
...@@ -69,14 +72,20 @@ class TensorflowCkptParser(object): ...@@ -69,14 +72,20 @@ class TensorflowCkptParser(object):
class TensorflowPbParser(object): class TensorflowPbParser(object):
def __init__(self, pb_file, dest_nodes, input_shape=None, def __init__(self,
in_nodes=None, input_format="NCHW".encode()): pb_file,
dest_nodes,
input_shape=None,
in_nodes=None,
input_format="NCHW".encode()):
with open(pb_file, 'rb') as f: with open(pb_file, 'rb') as f:
serialized = f.read() serialized = f.read()
tf.reset_default_graph() tf.reset_default_graph()
original_graph_def = tf.GraphDef() original_graph_def = tf.GraphDef()
original_graph_def.ParseFromString(serialized) original_graph_def.ParseFromString(serialized)
self.inputs = list()
self.outputs = dest_nodes
sess = tf.Session(graph=tf.get_default_graph()) sess = tf.Session(graph=tf.get_default_graph())
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
self.infer = ModelInfer(sess) self.infer = ModelInfer(sess)
...@@ -111,11 +120,11 @@ class TensorflowPbParser(object): ...@@ -111,11 +120,11 @@ class TensorflowPbParser(object):
raise Exception("Unexpected dtype for input, only support " \ raise Exception("Unexpected dtype for input, only support " \
"float32 and int32 now") "float32 and int32 now")
input_map[in_nodes[i] + ":0"] = x input_map[in_nodes[i] + ":0"] = x
self.inputs.append(x.name.split(':')[0])
self.infer.gen_sample_data(x.name, input_shape[i]) self.infer.gen_sample_data(x.name, input_shape[i])
tf.import_graph_def(graph_def, name="", input_map=input_map) tf.import_graph_def(graph_def, name="", input_map=input_map)
graph_def = tf.get_default_graph()._as_graph_def( graph_def = tf.get_default_graph()._as_graph_def(add_shapes=True)[0]
add_shapes=True)[0]
self.tf_graph = TensorflowGraph(graph_def) self.tf_graph = TensorflowGraph(graph_def)
self.tf_graph.build(input_format) self.tf_graph.build(input_format)
...@@ -164,7 +173,7 @@ class ModelInfer(object): ...@@ -164,7 +173,7 @@ class ModelInfer(object):
if len(tensor_name.split(':')) < 2: if len(tensor_name.split(':')) < 2:
tensor_name = tensor_name + ':0' tensor_name = tensor_name + ':0'
output_tensor = self.sess.graph.get_tensor_by_name(tensor_name) output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
tensor_values = [] tensor_values = []
for i in range(0, 3): for i in range(0, 3):
inputs_tensors = dict() inputs_tensors = dict()
...@@ -175,19 +184,19 @@ class ModelInfer(object): ...@@ -175,19 +184,19 @@ class ModelInfer(object):
inputs_tensors[tensor] = values[i] inputs_tensors[tensor] = values[i]
r, = self.sess.run([output_tensor], inputs_tensors) r, = self.sess.run([output_tensor], inputs_tensors)
tensor_values.append(r.flatten()) tensor_values.append(r.flatten())
compare01 = (tensor_values[0] == tensor_values[1]) compare01 = (tensor_values[0] == tensor_values[1])
compare12 = (tensor_values[1] == tensor_values[2]) compare12 = (tensor_values[1] == tensor_values[2])
if compare01.all() and compare12.all(): if compare01.all() and compare12.all():
return tensor_values[0] return tensor_values[0]
if (compare01 == compare12).all(): if (compare01 == compare12).all():
index = numpy.argwhere(compare01==False).flatten() index = numpy.argwhere(compare01 == False).flatten()
if index.shape[0] != 1: if index.shape[0] != 1:
raise Exception("There's not only one unstable dimension") raise Exception("There's not only one unstable dimension")
tensor_values[0][index[0]] = -1 tensor_values[0][index[0]] = -1
index = numpy.argwhere(tensor_values[0] < 0).flatten() index = numpy.argwhere(tensor_values[0] < 0).flatten()
if index.shape[0] > 2: if index.shape[0] > 2:
raise Exception("There's more than two values less than zero") raise Exception("There's more than two values less than zero")
...@@ -199,17 +208,17 @@ class ModelInfer(object): ...@@ -199,17 +208,17 @@ class ModelInfer(object):
return tensor_values[0] return tensor_values[0]
else: else:
raise Exception("Can not infer a stable shape tensor value") raise Exception("Can not infer a stable shape tensor value")
def get_tensor_shape(self, layer): def get_tensor_shape(self, layer):
shape = layer.attr['_output_shapes'].list.shape[0] shape = layer.attr['_output_shapes'].list.shape[0]
shape = numpy.array([dim.size for dim in shape.dim]) shape = numpy.array([dim.size for dim in shape.dim])
if numpy.argwhere(shape<0).shape[0] <= 1: if numpy.argwhere(shape < 0).shape[0] <= 1:
return shape return shape
tensor_name = layer.name tensor_name = layer.name
if len(tensor_name.split(':')) < 2: if len(tensor_name.split(':')) < 2:
tensor_name = tensor_name + ':0' tensor_name = tensor_name + ':0'
output_tensor = self.sess.graph.get_tensor_by_name(tensor_name) output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
shapes = [] shapes = []
for i in range(0, 3): for i in range(0, 3):
inputs_tensors = dict() inputs_tensors = dict()
...@@ -220,15 +229,15 @@ class ModelInfer(object): ...@@ -220,15 +229,15 @@ class ModelInfer(object):
inputs_tensors[tensor] = values[i] inputs_tensors[tensor] = values[i]
r, = self.sess.run([output_tensor], inputs_tensors) r, = self.sess.run([output_tensor], inputs_tensors)
shapes.append(numpy.array(r.shape)) shapes.append(numpy.array(r.shape))
compare01 = (shapes[0] == shapes[1]) compare01 = (shapes[0] == shapes[1])
compare12 = (shapes[1] == shapes[2]) compare12 = (shapes[1] == shapes[2])
if compare01.all() and compare12.all(): if compare01.all() and compare12.all():
return shapes[0] return shapes[0]
if (compare01 == compare12).all(): if (compare01 == compare12).all():
index = numpy.argwhere(compare01==False).flatten() index = numpy.argwhere(compare01 == False).flatten()
if index.shape[0] != 1: if index.shape[0] != 1:
raise Exception("There's not only one unstable dimension") raise Exception("There's not only one unstable dimension")
if index[0] != 0: if index[0] != 0:
...@@ -237,13 +246,13 @@ class ModelInfer(object): ...@@ -237,13 +246,13 @@ class ModelInfer(object):
return shapes[0] return shapes[0]
else: else:
raise Exception("Can not infer a stable tensor shape, failed!") raise Exception("Can not infer a stable tensor shape, failed!")
def get_const_tensor_value(self, layer): def get_const_tensor_value(self, layer):
tensor_name = layer.name tensor_name = layer.name
if len(tensor_name.split(':')) < 2: if len(tensor_name.split(':')) < 2:
tensor_name = tensor_name + ':0' tensor_name = tensor_name + ':0'
output_tensor = self.sess.graph.get_tensor_by_name(tensor_name) output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
result = [] result = []
for i in range(0, 3): for i in range(0, 3):
inputs_tensors = dict() inputs_tensors = dict()
...@@ -254,10 +263,10 @@ class ModelInfer(object): ...@@ -254,10 +263,10 @@ class ModelInfer(object):
inputs_tensors[tensor] = values[i] inputs_tensors[tensor] = values[i]
r, = self.sess.run([output_tensor], inputs_tensors) r, = self.sess.run([output_tensor], inputs_tensors)
result.append(r) result.append(r)
compare01 = (result[0] == result[1]) compare01 = (result[0] == result[1])
compare12 = (result[1] == result[2]) compare12 = (result[1] == result[2])
if compare01.all() and compare12.all(): if compare01.all() and compare12.all():
return result[0] return result[0]
else: else:
......
...@@ -22,7 +22,6 @@ VALID = 'VALID'.encode() ...@@ -22,7 +22,6 @@ VALID = 'VALID'.encode()
class NameGenerator(object): class NameGenerator(object):
def __init__(self): def __init__(self):
self.param_index = 0 self.param_index = 0
self.input_index = 0
self.net_index = 0 self.net_index = 0
self.const_index = 0 self.const_index = 0
self.names = dict() self.names = dict()
...@@ -38,8 +37,7 @@ class NameGenerator(object): ...@@ -38,8 +37,7 @@ class NameGenerator(object):
ref_name = "param_" + str(self.param_index) ref_name = "param_" + str(self.param_index)
self.param_index += 1 self.param_index += 1
elif op_name == "placeholder": elif op_name == "placeholder":
ref_name = "input_" + str(self.input_index) ref_name = node.layer.name
self.input_index += 1
elif op_name == "const": elif op_name == "const":
ref_name = "const_" + str(self.const_index) ref_name = "const_" + str(self.const_index)
self.const_index += 1 self.const_index += 1
...@@ -76,11 +74,13 @@ class LayerCode(object): ...@@ -76,11 +74,13 @@ class LayerCode(object):
layer_code2 = "" layer_code2 = ""
for k, v in self.param_attr.items(): for k, v in self.param_attr.items():
layer_code2 = layer_code2 + k + "=" + "{}".format(v) + ", " layer_code2 = layer_code2 + k + "=" + "{}".format(v) + ", "
layer_code2 = layer_code2.strip(", ") layer_code2 = layer_code2.strip(", ")
layer_code = (layer_code0 + layer_code1 + layer_code2).strip(", ") + ")" layer_code = (
layer_code0 + layer_code1 + layer_code2).strip(", ") + ")"
return layer_code return layer_code
class FluidCode(object): class FluidCode(object):
def __init__(self): def __init__(self):
self.codes = list() self.codes = list()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册