## How to Export a TensorFlow Model

This document describes how to export a TensorFlow model in the format supported by X2Paddle.
TensorFlow models are generally saved in one of three formats: checkpoint, FrozenModel, and SavedModel. X2Paddle supports the FrozenModel format, which stores the network parameters and the network structure together in a single file and keeps only the specified forward-computation subgraph. The examples below show how to export a model in this format.
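
If you already have a `.pb` file and are not sure whether it is a FrozenModel, a quick check is to parse it as a `GraphDef` and confirm that no variable nodes remain. This is a minimal sketch, assuming TensorFlow 1.x and a placeholder file name `model.pb`:
```python
# coding: utf-8
# Minimal sketch (TensorFlow 1.x): check whether a .pb file is a frozen GraphDef,
# i.e. all variables have been converted to constants. "model.pb" is a placeholder.
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile("model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

variable_nodes = [n.name for n in graph_def.node
                  if n.op in ("Variable", "VariableV2", "VarHandleOp")]
if variable_nodes:
    print("Not frozen; variable nodes remain:", variable_nodes)
else:
    print("Looks frozen: {} nodes, no variables".format(len(graph_def.node)))
```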

***The following code is intended for TensorFlow 1.x***
- checkpoint model + code
Step 1: Download the model parameter file
```shell
wget http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
tar xzvf vgg_16_2016_08_28.tar.gz
```

Step 2: Load and export the model
```python
#coding: utf-8
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import vgg
from tensorflow.python.framework import graph_util
import tensorflow as tf

# Freeze the model: convert variables to constants and serialize the graph
# output_tensor_names: list of the names of the model's output tensors
# freeze_model_path: file path to which the frozen model is exported
def freeze_model(sess, output_tensor_names, freeze_model_path):
    out_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_tensor_names)
    with tf.gfile.GFile(freeze_model_path, 'wb') as f:
        f.write(out_graph.SerializeToString())

    print("freeze model saved in {}".format(freeze_model_path))

# Load the model parameters
sess = tf.Session()
inputs = tf.placeholder(dtype=tf.float32,
                        shape=[None, 224, 224, 3],
                        name="inputs")
logits, endpoint = vgg.vgg_16(inputs, num_classes=1000, is_training=False)
load_model = slim.assign_from_checkpoint_fn(
    "vgg_16.ckpt", slim.get_model_variables("vgg_16"))
load_model(sess)

# Export the frozen model
freeze_model(sess, ["vgg_16/fc8/squeezed"], "vgg16.pb")
```
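
Before handing `vgg16.pb` to X2Paddle, it can be worth running one forward pass on the frozen graph to confirm the input and output tensor names. A minimal sketch, reusing the names `inputs` and `vgg_16/fc8/squeezed` from the export above:
```python
# coding: utf-8
# Minimal sketch (TensorFlow 1.x): load the frozen vgg16.pb exported above and run
# one forward pass on random data to verify the input/output tensor names.
import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile("vgg16.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

with tf.Session(graph=graph) as sess:
    inputs = graph.get_tensor_by_name("inputs:0")
    outputs = graph.get_tensor_by_name("vgg_16/fc8/squeezed:0")
    logits = sess.run(outputs,
                      feed_dict={inputs: np.random.rand(1, 224, 224, 3)})
    print("output shape:", logits.shape)  # expected: (1, 1000)
```
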
- Checkpoint-only model
File structure:
> |--- checkpoint  
> |--- model.ckpt-240000.data-00000-of-00001  
> |--- model.ckpt-240000.index  
> |--- model.ckpt-240000.meta  
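
The `input_checkpoint` prefix used in the export code below is the common prefix of the `.data`/`.index`/`.meta` files above (here `model.ckpt-240000`). If the directory contains a `checkpoint` file, `tf.train.latest_checkpoint` can resolve the prefix automatically; a minimal sketch, assuming a hypothetical directory `./ckpt_dir`:
```python
# coding: utf-8
# Minimal sketch (TensorFlow 1.x): resolve the checkpoint prefix from a directory.
# "./ckpt_dir" is a placeholder for the directory holding the files listed above.
import tensorflow as tf

input_checkpoint = tf.train.latest_checkpoint("./ckpt_dir")
print(input_checkpoint)  # e.g. ./ckpt_dir/model.ckpt-240000
```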

Load and export the model:
```python
#coding: utf-8
from tensorflow.python.framework import graph_util
import tensorflow as tf

# Freeze the model: convert variables to constants and serialize the graph
# output_tensor_names: list of the names of the model's output tensors
# freeze_model_path: file path to which the frozen model is exported
def freeze_model(sess, output_tensor_names, freeze_model_path):
    out_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_tensor_names)
    with tf.gfile.GFile(freeze_model_path, 'wb') as f:
        f.write(out_graph.SerializeToString())

    print("freeze model saved in {}".format(freeze_model_path))
    
# Load the model parameters
# Modify input_checkpoint (the checkpoint prefix) and save_pb_file (the export path) for your own model
input_checkpoint = "./tfhub_models/save/model.ckpt"
save_pb_file = "./tfhub_models/save.pb"
sess = tf.Session()
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
saver.restore(sess, input_checkpoint)

# Modify the second argument of freeze_model to your model's output tensor names
# (one way to list candidate names is shown in the sketch after this code block)
freeze_model(sess, ["vgg_16/fc8/squeezed"], save_pb_file)
```
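
If the output tensor names of a checkpoint model are unknown, one way to find candidates is to restore the meta graph and list the operations near the end of the graph; the chosen operation name is then passed to `freeze_model` above. A minimal sketch, reusing the `input_checkpoint` prefix from the code above:
```python
# coding: utf-8
# Minimal sketch (TensorFlow 1.x): list operation names in the restored meta graph
# to help locate the output tensor name(s) passed to freeze_model.
import tensorflow as tf

input_checkpoint = "./tfhub_models/save/model.ckpt"  # same prefix as above

tf.reset_default_graph()
tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph()

# The forward-pass output is usually among the last operations in the graph.
for op in graph.get_operations()[-10:]:
    print(op.name, op.type)
```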

- SavedModel
File structure:
> |-- variables  
> |------ variables.data-00000-of-00001  
> |------ variables.index  
> |-- saved_model.pb  

Load and export the model:
```python
#coding: utf-8
import tensorflow as tf
sess = tf.Session(graph=tf.Graph())
# The second argument of tf.saved_model.loader.load is the set of tags the model was
# saved with (commonly [tf.saved_model.tag_constants.SERVING]); an empty set only matches
# a MetaGraphDef exported without tags. The last argument is the SavedModel directory.
tf.saved_model.loader.load(sess, {}, "/mnt/saved_model")
graph = tf.get_default_graph()

from tensorflow.python.framework import graph_util
# Freeze the model: convert variables to constants and serialize the graph
# output_tensor_names: list of the names of the model's output tensors
# freeze_model_path: file path to which the frozen model is exported
def freeze_model(sess, output_tensor_names, freeze_model_path):
    out_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_tensor_names)
    with tf.gfile.GFile(freeze_model_path, 'wb') as f:
        f.write(out_graph.SerializeToString())

    print("freeze model saved in {}".format(freeze_model_path))

# Export the model (replace "logits" with the actual output tensor names of your model)
freeze_model(sess, ["logits"], "model.pb")
```
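
If you do not know which tags or output tensor names a SavedModel was exported with, its `saved_model.pb` can be parsed directly to list them (TensorFlow's `saved_model_cli show --dir <path> --all` prints the same information). A minimal sketch, assuming the same placeholder directory `/mnt/saved_model`:
```python
# coding: utf-8
# Minimal sketch (TensorFlow 1.x): list the tags and signature outputs stored in a
# SavedModel, to decide what to pass to tf.saved_model.loader.load and freeze_model.
import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2

saved_model = saved_model_pb2.SavedModel()
with tf.gfile.GFile("/mnt/saved_model/saved_model.pb", "rb") as f:
    saved_model.ParseFromString(f.read())

for meta_graph in saved_model.meta_graphs:
    print("tags:", list(meta_graph.meta_info_def.tags))
    for sig_name, signature in meta_graph.signature_def.items():
        print("signature:", sig_name)
        for key, tensor_info in signature.outputs.items():
            # tensor_info.name looks like "logits:0"; drop the ":0" suffix
            # when passing the name to freeze_model
            print("  output:", key, "->", tensor_info.name)
```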