未验证 提交 735bab4e 编写于 作者: J Jason 提交者: GitHub

新增vgg_16模型转换tutorial

上级 0c46d4de
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial : tensorflow2fluid转换VGG_16模型\n",
"\n",
"VGG_16是CV领域的一个经典模型,本文档以tensorflow/models下的[VGG_16](https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py)为例,展示如何将TensorFlow训练好的模型转换为PaddlePaddle模型。 \n",
"### 下载预训练模型"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Download percentage 100.00%"
]
}
],
"source": [
"import urllib\n",
"import sys\n",
"def schedule(a, b, c):\n",
" per = 100.0 * a * b / c\n",
" per = int(per)\n",
" sys.stderr.write(\"\\rDownload percentage %.2f%%\" % per)\n",
" sys.stderr.flush()\n",
"\n",
"url = \"http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz\"\n",
"fetch = urllib.urlretrieve(url, \"./vgg_16.tar.gz\", schedule)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 解压下载的压缩文件"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import tarfile\n",
"with tarfile.open(\"./vgg_16.tar.gz\", \"r:gz\") as f:\n",
" file_names = f.getnames()\n",
" for file_name in file_names:\n",
" f.extract(file_name, \"./\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 保存模型为checkpoint格式\n",
"\n",
"tensorflow2fluid目前支持checkpoint格式的模型或者是将网络结构和参数序列化的pb格式模型,上面下载的`vgg_16.ckpt`仅仅存储了模型参数,因此我们需要重新加载参数,并将网络结构和参数一起保存为checkpoint模型"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from vgg_16.ckpt\n"
]
}
],
"source": [
"import tensorflow.contrib.slim as slim\n",
"from tensorflow.contrib.slim.nets import vgg\n",
"import tensorflow as tf\n",
"import numpy\n",
"\n",
"with tf.Session() as sess:\n",
" inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name=\"inputs\")\n",
" logits, endpoint = vgg.vgg_16(inputs, num_classes=1000, is_training=False)\n",
" load_model = slim.assign_from_checkpoint_fn(\"vgg_16.ckpt\", slim.get_model_variables(\"vgg_16\"))\n",
" load_model(sess)\n",
" \n",
" numpy.random.seed(13)\n",
" data = numpy.random.rand(5, 224, 224, 3)\n",
" input_tensor = sess.graph.get_tensor_by_name(\"inputs:0\")\n",
" output_tensor = sess.graph.get_tensor_by_name(\"vgg_16/fc8/squeezed:0\")\n",
" result = sess.run([output_tensor], {input_tensor:data})\n",
" numpy.save(\"tensorflow.npy\", numpy.array(result))\n",
" \n",
" saver = tf.train.Saver()\n",
" saver.save(sess, \"./checkpoint/model\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 将模型转换为PaddlePaddle模型\n",
"\n",
"注意:部分OP在转换时,需要将参数写入文件;或者是运行tensorflow模型进行infer,获取tensor值。两种情况下均会消耗一定的时间用于IO或计算,对于后一种情况,建议转换模型时将`use_cuda`参数设为`True`,加快infer速度"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Loading tensorflow model...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from checkpoint/model\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from checkpoint/model\n",
"INFO:root:Tensorflow model loaded!\n",
"INFO:root:TotalNum:86,TraslatedNum:1,CurrentNode:inputs\n",
"INFO:root:TotalNum:86,TraslatedNum:2,CurrentNode:vgg_16/conv1/conv1_1/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:3,CurrentNode:vgg_16/conv1/conv1_1/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:4,CurrentNode:vgg_16/conv1/conv1_2/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:5,CurrentNode:vgg_16/conv1/conv1_2/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:6,CurrentNode:vgg_16/conv2/conv2_1/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:7,CurrentNode:vgg_16/conv2/conv2_1/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:8,CurrentNode:vgg_16/conv2/conv2_2/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:9,CurrentNode:vgg_16/conv2/conv2_2/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:10,CurrentNode:vgg_16/conv3/conv3_1/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:11,CurrentNode:vgg_16/conv3/conv3_1/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:12,CurrentNode:vgg_16/conv3/conv3_2/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:13,CurrentNode:vgg_16/conv3/conv3_2/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:14,CurrentNode:vgg_16/conv3/conv3_3/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:15,CurrentNode:vgg_16/conv3/conv3_3/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:16,CurrentNode:vgg_16/conv4/conv4_1/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:17,CurrentNode:vgg_16/conv4/conv4_1/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:18,CurrentNode:vgg_16/conv4/conv4_2/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:19,CurrentNode:vgg_16/conv4/conv4_2/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:20,CurrentNode:vgg_16/conv4/conv4_3/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:21,CurrentNode:vgg_16/conv4/conv4_3/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:22,CurrentNode:vgg_16/conv5/conv5_1/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:23,CurrentNode:vgg_16/conv5/conv5_1/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:24,CurrentNode:vgg_16/conv5/conv5_2/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:25,CurrentNode:vgg_16/conv5/conv5_2/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:26,CurrentNode:vgg_16/conv5/conv5_3/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:27,CurrentNode:vgg_16/conv5/conv5_3/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:28,CurrentNode:vgg_16/fc6/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:29,CurrentNode:vgg_16/fc6/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:30,CurrentNode:vgg_16/fc7/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:31,CurrentNode:vgg_16/fc7/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:32,CurrentNode:vgg_16/fc8/weights\n",
"INFO:root:TotalNum:86,TraslatedNum:33,CurrentNode:vgg_16/fc8/biases\n",
"INFO:root:TotalNum:86,TraslatedNum:34,CurrentNode:vgg_16/conv1/conv1_1/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:35,CurrentNode:vgg_16/conv1/conv1_1/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:36,CurrentNode:vgg_16/conv1/conv1_1/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:37,CurrentNode:vgg_16/conv1/conv1_2/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:38,CurrentNode:vgg_16/conv1/conv1_2/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:39,CurrentNode:vgg_16/conv1/conv1_2/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:40,CurrentNode:vgg_16/pool1/MaxPool\n",
"INFO:root:TotalNum:86,TraslatedNum:41,CurrentNode:vgg_16/conv2/conv2_1/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:42,CurrentNode:vgg_16/conv2/conv2_1/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:43,CurrentNode:vgg_16/conv2/conv2_1/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:44,CurrentNode:vgg_16/conv2/conv2_2/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:45,CurrentNode:vgg_16/conv2/conv2_2/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:46,CurrentNode:vgg_16/conv2/conv2_2/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:47,CurrentNode:vgg_16/pool2/MaxPool\n",
"INFO:root:TotalNum:86,TraslatedNum:48,CurrentNode:vgg_16/conv3/conv3_1/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:49,CurrentNode:vgg_16/conv3/conv3_1/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:50,CurrentNode:vgg_16/conv3/conv3_1/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:51,CurrentNode:vgg_16/conv3/conv3_2/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:52,CurrentNode:vgg_16/conv3/conv3_2/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:53,CurrentNode:vgg_16/conv3/conv3_2/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:54,CurrentNode:vgg_16/conv3/conv3_3/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:55,CurrentNode:vgg_16/conv3/conv3_3/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:56,CurrentNode:vgg_16/conv3/conv3_3/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:57,CurrentNode:vgg_16/pool3/MaxPool\n",
"INFO:root:TotalNum:86,TraslatedNum:58,CurrentNode:vgg_16/conv4/conv4_1/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:59,CurrentNode:vgg_16/conv4/conv4_1/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:60,CurrentNode:vgg_16/conv4/conv4_1/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:61,CurrentNode:vgg_16/conv4/conv4_2/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:62,CurrentNode:vgg_16/conv4/conv4_2/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:63,CurrentNode:vgg_16/conv4/conv4_2/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:64,CurrentNode:vgg_16/conv4/conv4_3/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:65,CurrentNode:vgg_16/conv4/conv4_3/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:66,CurrentNode:vgg_16/conv4/conv4_3/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:67,CurrentNode:vgg_16/pool4/MaxPool\n",
"INFO:root:TotalNum:86,TraslatedNum:68,CurrentNode:vgg_16/conv5/conv5_1/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:69,CurrentNode:vgg_16/conv5/conv5_1/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:70,CurrentNode:vgg_16/conv5/conv5_1/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:71,CurrentNode:vgg_16/conv5/conv5_2/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:72,CurrentNode:vgg_16/conv5/conv5_2/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:73,CurrentNode:vgg_16/conv5/conv5_2/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:74,CurrentNode:vgg_16/conv5/conv5_3/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:75,CurrentNode:vgg_16/conv5/conv5_3/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:76,CurrentNode:vgg_16/conv5/conv5_3/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:77,CurrentNode:vgg_16/pool5/MaxPool\n",
"INFO:root:TotalNum:86,TraslatedNum:78,CurrentNode:vgg_16/fc6/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:79,CurrentNode:vgg_16/fc6/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:80,CurrentNode:vgg_16/fc6/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:81,CurrentNode:vgg_16/fc7/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:82,CurrentNode:vgg_16/fc7/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:83,CurrentNode:vgg_16/fc7/Relu\n",
"INFO:root:TotalNum:86,TraslatedNum:84,CurrentNode:vgg_16/fc8/Conv2D\n",
"INFO:root:TotalNum:86,TraslatedNum:85,CurrentNode:vgg_16/fc8/BiasAdd\n",
"INFO:root:TotalNum:86,TraslatedNum:86,CurrentNode:vgg_16/fc8/squeezed\n",
"INFO:root:Model translated!\n"
]
}
],
"source": [
"import tf2fluid.convert as convert\n",
"import argparse\n",
"parser = convert._get_parser()\n",
"parser.meta_file = \"checkpoint/model.meta\"\n",
"parser.ckpt_dir = \"checkpoint\"\n",
"parser.in_nodes = [\"inputs\"]\n",
"parser.input_shape = [\"None,224,224,3\"]\n",
"parser.output_nodes = [\"vgg_16/fc8/squeezed\"]\n",
"parser.use_cuda = \"True\"\n",
"parser.input_format = \"NHWC\"\n",
"parser.save_dir = \"paddle_model\"\n",
"\n",
"convert.run(parser)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 加载转换后的PaddlePaddle模型,并进行预测\n",
"**需要注意,转换后的PaddlePaddle CV模型输入格式为NCHW**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-03-14T10:18:46.124339Z",
"start_time": "2019-03-14T10:18:40.858372Z"
}
},
"outputs": [],
"source": [
"import numpy\n",
"import tf2fluid.model_loader as ml\n",
"\n",
"model = ml.ModelLoader(\"paddle_model\", use_cuda=True)\n",
"\n",
"numpy.random.seed(13)\n",
"data = numpy.random.rand(5, 224, 224, 3).astype(\"float32\")\n",
"# NHWC -> NCHW\n",
"data = numpy.transpose(data, (0, 3, 1, 2))\n",
"\n",
"results = model.inference(feed_dict={model.inputs[0]:data})\n",
"\n",
"numpy.save(\"paddle.npy\", numpy.array(results))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 对比转换前后模型之前的预测结果diff"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-03-14T10:20:13.611132Z",
"start_time": "2019-03-14T10:20:13.598874Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.33786e-06\n"
]
}
],
"source": [
"import numpy\n",
"paddle_result = numpy.load(\"paddle.npy\")\n",
"tensorflow_result = numpy.load(\"tensorflow.npy\")\n",
"diff = numpy.fabs(paddle_result - tensorflow_result)\n",
"print(numpy.max(diff))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:CPU-Paddle]",
"language": "python",
"name": "conda-env-CPU-Paddle-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.15"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册