From e7e3881aa7e428f6743c48e184fbb9ab7b8787c8 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Sun, 13 Jan 2019 18:40:00 +0800 Subject: [PATCH] add demo --- .../demo/export_resnet_to_paddle_model.py | 13 ++++++++ .../demo/save_resnet_ckpt_model.py | 31 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 TensorFlow2Paddle/demo/export_resnet_to_paddle_model.py create mode 100644 TensorFlow2Paddle/demo/save_resnet_ckpt_model.py diff --git a/TensorFlow2Paddle/demo/export_resnet_to_paddle_model.py b/TensorFlow2Paddle/demo/export_resnet_to_paddle_model.py new file mode 100644 index 0000000..0530a33 --- /dev/null +++ b/TensorFlow2Paddle/demo/export_resnet_to_paddle_model.py @@ -0,0 +1,13 @@ +import sys +sys.path.append(".") +from transformer import Transformer + +meta_file = sys.argv[1] +ckpt_dir = sys.argv[2] +export_dir = sys.argv[3] + +transformer = Transformer(meta_file, ckpt_dir, ['resnet_v1_50/pool5'], + (224, 224, 3), ['inputs']) +transformer.run(export_dir) + +open(export_dir + "/__init__.py", "w").close() diff --git a/TensorFlow2Paddle/demo/save_resnet_ckpt_model.py b/TensorFlow2Paddle/demo/save_resnet_ckpt_model.py new file mode 100644 index 0000000..610d2d2 --- /dev/null +++ b/TensorFlow2Paddle/demo/save_resnet_ckpt_model.py @@ -0,0 +1,31 @@ +from tensorflow.contrib.slim.nets import resnet_v1 as resnet_v1 +import tensorflow.contrib.slim as slim +import tensorflow as tf +import sys + + +def load_model(ckpt_file): + img_size = resnet_v1.resnet_v1.default_image_size + img = tf.placeholder( + tf.float32, shape=[None, img_size, img_size, 3], name='inputs') + with slim.arg_scope(resnet_v1.resnet_arg_scope()): + net, endpoint = resnet_v1.resnet_v1_50( + img, num_classes=None, is_training=False) + + sess = tf.Session() + load_model = tf.contrib.slim.assign_from_checkpoint_fn( + ckpt_file, tf.contrib.slim.get_model_variables("resnet_v1_50")) + load_model(sess) + return sess + + +def save_checkpoint(sess, save_dir): + saver = tf.train.Saver() + saver.save(sess, save_dir + "/resnet") + + +if __name__ == "__main__": + ckpt_file = sys.argv[1] + save_dir = sys.argv[2] + sess = load_model(ckpt_file) + save_checkpoint(sess, save_dir) -- GitLab