提交 67fe580d 编写于 作者: J jiangjiajun

add demo & doc

上级 e75e3a67
from src.paddle_emitter import PaddleEmitter
from src.tensorflow_parser import TensorflowCkptParser
from src.tensorflow_parser import TensorflowPbParser
from six import text_type as _text_type
import argparse
import sys
import os
def _get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--meta_file", "-m", type=_text_type, default=None, help="meta file path for checkpoint format")
parser.add_argument("--ckpt_dir", "-c", type=_text_type, default=None, help="checkpoint directory")
parser.add_argument("--pb_file", "-p", type=_text_type, default=None, help="pb model file path")
parser.add_argument("--in_nodes", "-i", type=_text_type, nargs="+", default=None, help="input nodes name")
parser.add_argument("--input_shape", "-is", type=_text_type, nargs="+", default=None, help="input tensor shape")
parser.add_argument("--output_nodes", "-o", type=_text_type, nargs="+", default=None, help="output nodes name")
parser.add_argument("--save_dir", "-s", type=_text_type, default=None, help="path to save transformed paddle model")
return parser
def _convert(args):
if args.meta_file is None and args.pb_file is None:
raise Exception("Need to define --meta_file or --pb_file")
assert args.in_nodes is not None
assert args.output_nodes is not None
assert args.input_shape is not None
assert args.save_dir is not None
if os.path.exists(args.save_dir):
sys.stderr.write("save_dir already exists, change to a new path\n")
return
os.makedirs(args.save_dir)
input_shape = list()
for shape_str in args.input_shape:
items = shape_str.split(',')
for i in range(len(items)):
if items[i] != "None":
items[i] = int(items[i])
input_shape.append(items)
if args.meta_file is not None:
parser = TensorflowCkptParser(args.meta_file, args.ckpt_dir, args.output_nodes, input_shape, args.in_nodes)
else:
parser = TensorflowPbParser(args.pb_file, args.output_nodes, input_shape, args.in_nodes)
emitter = PaddleEmitter(parser, args.save_dir)
emitter.run()
if __name__ == "__main__":
parser = _get_parser()
args = parser.parse_args()
_convert(args)
from tensorflow.contrib.slim.nets import inception
from tensorflow.contrib.framework.python.ops import arg_scope
import tensorflow.contrib.slim as slim
import tensorflow as tf
import numpy
import sys
numpy.random.seed(13)
ckpt_file = sys.argv[1]
checkpoint_dir = sys.argv[2]
def get_tuned_variables():
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
variables_to_restore = []
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
return variables_to_restore
def load_model():
img_size = inception.inception_v3.default_image_size
img = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3], name='inputs')
with slim.arg_scope(inception.inception_v3_arg_scope()):
logits, _ = inception.inception_v3(img, num_classes=1000, is_training=False)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
load_model = tf.contrib.slim.assign_from_checkpoint_fn(ckpt_file, get_tuned_variables(), ignore_missing_vars=True)
load_model(sess)
return sess
def save_checkpoint(sess):
saver = tf.train.Saver()
saver.save(sess, checkpoint_dir+"/model")
if __name__ == "__main__":
sess = load_model()
save_checkpoint(sess)
......@@ -124,8 +124,8 @@ class PaddleEmitter(object):
ref_name_recorder = open(self.save_dir + "/ref_name.txt", 'w')
total_nodes_num = len(self.graph.topological_sort)
translated_nodes_count = 0
sys.stderr.write("Start to translate all the nodes(Total_num:{})\n".
translated_nodes_count = 1
sys.stderr.write("\nStart to translate all the nodes(Total_num:{})\n".
format(total_nodes_num))
for node in self.graph.topological_sort:
sys.stderr.write(
......@@ -168,6 +168,9 @@ class PaddleEmitter(object):
filew.write(self.body_code)
filew.close()
sys.stderr.write("Model translated!\n\n")
sys.stderr.flush()
return self.body_code
def emit_placeholder(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册