提交 e5fdd6d4 编写于 作者: J jiangjiajun

add setup.py

上级 db77a07f
from tensorflow.contrib.slim.nets import inception
from tensorflow.contrib.slim.nets import vgg as vgg
from tensorflow.contrib.slim.nets import resnet_v1 as resnet_v1
from tensorflow.contrib.framework.python.ops import arg_scope
import tensorflow.contrib.slim as slim
import tensorflow as tf
import numpy
from six import text_type as _text_type
def inception_v3(ckpt_file):
def get_tuned_variables():
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
variables_to_restore = []
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
return variables_to_restore
img_size = inception.inception_v3.default_image_size
img = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3], name='inputs')
with slim.arg_scope(inception.inception_v3_arg_scope()):
logits, _ = inception.inception_v3(img, num_classes=1000, is_training=False)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
load_model = tf.contrib.slim.assign_from_checkpoint_fn(ckpt_file, get_tuned_variables(), ignore_missing_vars=True)
load_model(sess)
return sess
def resnet_v1_50(ckpt_file):
img_size = resnet_v1.resnet_v1.default_image_size
img = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3], name='inputs')
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
net, endpoint = resnet_v1.resnet_v1_50(img, num_classes=None, is_training=False)
sess = tf.Session()
load_model = tf.contrib.slim.assign_from_checkpoint_fn(ckpt_file, tf.contrib.slim.get_model_variables("resnet_v1_50"))
load_model(sess)
return sess
def resnet_v1_101(ckpt_file):
img_size = resnet_v1.resnet_v1.default_image_size
img = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3], name='inputs')
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
net, endpoint = resnet_v1.resnet_v1_101(img, num_classes=None, is_training=False)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
load_model = tf.contrib.slim.assign_from_checkpoint_fn(ckpt_file, tf.contrib.slim.get_model_variables("resnet_v1_101"))
load_model(sess)
return sess
def vgg_16(ckpt_file):
img_size = vgg.vgg_16.default_image_size
inputs = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3],
name="inputs")
logits, endpoint = vgg.vgg_16(inputs, num_classes=1000, is_training=False)
sess = tf.Session()
load_model = tf.contrib.slim.assign_from_checkpoint_fn(ckpt_file,
tf.contrib.slim.get_model_variables("vgg_16"))
load_model(sess)
return sess
def vgg_19(ckpt_file):
img_size = vgg.vgg_19.default_image_size
inputs = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3],
name="inputs")
logits, endpoint = vgg.vgg_19(inputs, num_classes=1000, is_training=False)
sess = tf.Session()
load_model = tf.contrib.slim.assign_from_checkpoint_fn(ckpt_file,
tf.contrib.slim.get_model_variables("vgg_19"))
load_model(sess)
return sess
def save_checkpoint(sess, save_dir):
saver = tf.train.Saver()
saver.save(sess, save_dir+"/model")
def get_parser():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m", type=_text_type, default=None, help="inception_v3/resnet_v1_50/resnet_v1_101/vgg_16/vgg_19")
parser.add_argument("--ckpt_file", "-c", type=_text_type, default=None, help="parameters ckpt file")
parser.add_argument("--save_dir", "-s", type=_text_type, default=None, help="model path")
return parser
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
sess = None
if args.model is None or args.save_dir is None or args.ckpt_file is None:
raise Exception("--model, --ckpt_file and --save_dir are needed")
if args.model == "inception_v3":
sess = inception_v3(args.ckpt_file)
elif args.model == "resnet_v1_50":
sess = resnet_v1_50(args.ckpt_file)
elif args.model == "resnet_v1_101":
sess = resnet_v1_101(args.ckpt_file)
elif args.model == "vgg_16":
sess = vgg_16(args.ckpt_file)
elif args.model == "vgg_19":
sess = vgg_19(args.ckpt_file)
else:
raise Exception("Only support inception_v3/resnet_v1_50/resnet_v1_101/vgg_16/vgg_19")
save_checkpoint(sess, args.save_dir)
from tensorflow.contrib.slim.nets import inception
from tensorflow.contrib.framework.python.ops import arg_scope
import tensorflow.contrib.slim as slim
import tensorflow as tf
import numpy
import sys
numpy.random.seed(13)
ckpt_file = sys.argv[1]
checkpoint_dir = sys.argv[2]
def get_tuned_variables():
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
variables_to_restore = []
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
return variables_to_restore
def load_model():
img_size = inception.inception_v3.default_image_size
img = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3], name='inputs')
with slim.arg_scope(inception.inception_v3_arg_scope()):
logits, _ = inception.inception_v3(img, num_classes=1000, is_training=False)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
load_model = tf.contrib.slim.assign_from_checkpoint_fn(ckpt_file, get_tuned_variables(), ignore_missing_vars=True)
load_model(sess)
return sess
def save_checkpoint(sess):
saver = tf.train.Saver()
saver.save(sess, checkpoint_dir+"/model")
if __name__ == "__main__":
sess = load_model()
save_checkpoint(sess)
from src.paddle_emitter import PaddleEmitter
from src.tensorflow_parser import TensorflowCkptParser
from src.tensorflow_parser import TensorflowPbParser
from paddle_emitter import PaddleEmitter
from tensorflow_parser import TensorflowCkptParser
from tensorflow_parser import TensorflowPbParser
from six import text_type as _text_type
import argparse
import sys
......@@ -15,6 +15,7 @@ def _get_parser():
parser.add_argument("--input_shape", "-is", type=_text_type, nargs="+", default=None, help="input tensor shape")
parser.add_argument("--output_nodes", "-o", type=_text_type, nargs="+", default=None, help="output nodes name")
parser.add_argument("--save_dir", "-s", type=_text_type, default=None, help="path to save transformed paddle model")
parser.add_argument("--version", "-v", action="version", version="tensorflow2fluid version=0.0.1 Release @2019.01.28")
return parser
def _convert(args):
......@@ -39,14 +40,22 @@ def _convert(args):
items[i] = int(items[i])
input_shape.append(items)
sys.stderr.write("\nLoading tensorflow model......\n")
if args.meta_file is not None:
parser = TensorflowCkptParser(args.meta_file, args.ckpt_dir, args.output_nodes, input_shape, args.in_nodes)
else:
parser = TensorflowPbParser(args.pb_file, args.output_nodes, input_shape, args.in_nodes)
sys.stderr.write("Tensorflow model loaded!\n")
emitter = PaddleEmitter(parser, args.save_dir)
emitter.run()
if __name__ == "__main__":
open(args.save_dir+"/__init__.py", "w").close()
def _main():
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
parser = _get_parser()
args = parser.parse_args()
_convert(args)
if __name__ == "__main__":
_main()
......@@ -66,6 +66,7 @@ class PaddleEmitter(object):
return axis
def export_weights(self, weight, paddle_var_name, dir):
self.save_var_set.add(paddle_var_name)
numpy_dtype_map = {
"int16": framework.VarType.INT16,
"int32": framework.VarType.INT32,
......@@ -119,13 +120,17 @@ class PaddleEmitter(object):
self.body_code += (self.tab * indent) + code + "\n"
def run(self):
print("new version")
node = self.graph.tf_graph.node[0]
self.add_codes(0, self.header_code)
self.save_var_set = set()
ref_name_recorder = open(self.save_dir + "/ref_name.txt", 'w')
total_nodes_num = len(self.graph.topological_sort)
translated_nodes_count = 1
sys.stderr.write("\nStart to translate all the nodes(Total_num:{})\n".
sys.stderr.write("\nModel Translating......\n")
sys.stderr.write("Start to translate all the nodes(Total_num:{})\n".
format(total_nodes_num))
for node in self.graph.topological_sort:
sys.stderr.write(
......@@ -167,6 +172,10 @@ class PaddleEmitter(object):
filew = open(self.save_dir + "/mymodel.py", 'w')
filew.write(self.body_code)
filew.close()
filew = open(self.save_dir + "/save_var.list", 'w')
for var in self.save_var_set:
filew.write(var + '\n')
filew.close()
sys.stderr.write("Model translated!\n\n")
sys.stderr.flush()
......
......@@ -18,7 +18,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
class TensorflowCkptParser(object):
def __init__(self,
meta_file,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册