{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial: Converting the VGG_16 model with tensorflow2fluid\n", "\n", "VGG_16 is a classic model in computer vision. Using [VGG_16](https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py) from tensorflow/models as an example, this document shows how to convert a model trained with TensorFlow into a PaddlePaddle model.  \n", "### Download the pretrained model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "# In Python 3, urlretrieve lives in urllib.request, not in urllib itself\n", "import urllib.request\n", "\n", "def schedule(a, b, c):\n", "    # Progress hook: a = blocks transferred so far, b = block size, c = total file size\n", "    per = 100.0 * a * b / c\n", "    per = int(per)\n", "    sys.stderr.write(\"\\rDownload percentage %.2f%%\" % per)\n", "    sys.stderr.flush()\n", "\n", "url = \"http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz\"\n", "fetch = urllib.request.urlretrieve(url, \"./vgg_16.tar.gz\", schedule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Extract the downloaded archive" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ 
"import tarfile\n", "with tarfile.open(\"./vgg_16.tar.gz\", \"r:gz\") as f:\n", " file_names = f.getnames()\n", " for file_name in file_names:\n", " f.extract(file_name, \"./\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 保存模型为checkpoint格式\n", "\n", "tensorflow2fluid目前支持checkpoint格式的模型或者是将网络结构和参数序列化的pb格式模型,上面下载的`vgg_16.ckpt`仅仅存储了模型参数,因此我们需要重新加载参数,并将网络结构和参数一起保存为checkpoint模型\n", "\n", "**注意:下面的代码里,运行TensorFlow模型和将TensorFlow模型转换为PaddlePaddle模型,依赖TensorFlow**" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from vgg_16.ckpt\n" ] } ], "source": [ "import tensorflow.contrib.slim as slim\n", "from tensorflow.contrib.slim.nets import vgg\n", "import tensorflow as tf\n", "import numpy\n", "\n", "with tf.Session() as sess:\n", " inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name=\"inputs\")\n", " logits, endpoint = vgg.vgg_16(inputs, num_classes=1000, is_training=False)\n", " load_model = slim.assign_from_checkpoint_fn(\"vgg_16.ckpt\", slim.get_model_variables(\"vgg_16\"))\n", " load_model(sess)\n", " \n", " numpy.random.seed(13)\n", " data = numpy.random.rand(5, 224, 224, 3)\n", " input_tensor = sess.graph.get_tensor_by_name(\"inputs:0\")\n", " output_tensor = sess.graph.get_tensor_by_name(\"vgg_16/fc8/squeezed:0\")\n", " result = sess.run([output_tensor], {input_tensor:data})\n", " numpy.save(\"tensorflow.npy\", numpy.array(result))\n", " \n", " saver = tf.train.Saver()\n", " saver.save(sess, \"./checkpoint/model\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 将模型转换为PaddlePaddle模型\n", "\n", "**注意**:部分OP在转换时,需要将参数写入文件;或者是运行tensorflow模型进行infer,获取tensor值。两种情况下均会消耗一定的时间用于IO或计算,对于后一种情况,建议转换模型时将`use_cuda`参数设为`True`,加快infer速度\n", "\n", "可以通过下面的**模型转换python脚本**在代码中设置参数,在python脚本中进行模型转换。或者一般可以通过如下的命令行方式进行转换,\n", "``` python\n", "# 通过命令行也可进行模型转换\n", "python tf2fluid/convert.py --meta_file 
checkpoint/model.meta --ckpt_dir checkpoint \\\n", " --in_nodes inputs --input_shape None,224,224,3 \\\n", " --output_nodes vgg_16/fc8/squeezed --use_cuda True \\\n", " --input_format NHWC --save_dir paddle_model\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Model conversion Python script" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Loading tensorflow model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from checkpoint/model\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from checkpoint/model\n", "INFO:root:Tensorflow model loaded!\n", "INFO:root:TotalNum:86,TraslatedNum:1,CurrentNode:inputs\n", "INFO:root:TotalNum:86,TraslatedNum:2,CurrentNode:vgg_16/conv1/conv1_1/weights\n", "INFO:root:TotalNum:86,TraslatedNum:3,CurrentNode:vgg_16/conv1/conv1_1/biases\n", "INFO:root:TotalNum:86,TraslatedNum:4,CurrentNode:vgg_16/conv1/conv1_2/weights\n", "INFO:root:TotalNum:86,TraslatedNum:5,CurrentNode:vgg_16/conv1/conv1_2/biases\n", "INFO:root:TotalNum:86,TraslatedNum:6,CurrentNode:vgg_16/conv2/conv2_1/weights\n", "INFO:root:TotalNum:86,TraslatedNum:7,CurrentNode:vgg_16/conv2/conv2_1/biases\n", "INFO:root:TotalNum:86,TraslatedNum:8,CurrentNode:vgg_16/conv2/conv2_2/weights\n", "INFO:root:TotalNum:86,TraslatedNum:9,CurrentNode:vgg_16/conv2/conv2_2/biases\n", "INFO:root:TotalNum:86,TraslatedNum:10,CurrentNode:vgg_16/conv3/conv3_1/weights\n", "INFO:root:TotalNum:86,TraslatedNum:11,CurrentNode:vgg_16/conv3/conv3_1/biases\n", "INFO:root:TotalNum:86,TraslatedNum:12,CurrentNode:vgg_16/conv3/conv3_2/weights\n", "INFO:root:TotalNum:86,TraslatedNum:13,CurrentNode:vgg_16/conv3/conv3_2/biases\n", "INFO:root:TotalNum:86,TraslatedNum:14,CurrentNode:vgg_16/conv3/conv3_3/weights\n", "INFO:root:TotalNum:86,TraslatedNum:15,CurrentNode:vgg_16/conv3/conv3_3/biases\n", 
"INFO:root:TotalNum:86,TraslatedNum:16,CurrentNode:vgg_16/conv4/conv4_1/weights\n", "INFO:root:TotalNum:86,TraslatedNum:17,CurrentNode:vgg_16/conv4/conv4_1/biases\n", "INFO:root:TotalNum:86,TraslatedNum:18,CurrentNode:vgg_16/conv4/conv4_2/weights\n", "INFO:root:TotalNum:86,TraslatedNum:19,CurrentNode:vgg_16/conv4/conv4_2/biases\n", "INFO:root:TotalNum:86,TraslatedNum:20,CurrentNode:vgg_16/conv4/conv4_3/weights\n", "INFO:root:TotalNum:86,TraslatedNum:21,CurrentNode:vgg_16/conv4/conv4_3/biases\n", "INFO:root:TotalNum:86,TraslatedNum:22,CurrentNode:vgg_16/conv5/conv5_1/weights\n", "INFO:root:TotalNum:86,TraslatedNum:23,CurrentNode:vgg_16/conv5/conv5_1/biases\n", "INFO:root:TotalNum:86,TraslatedNum:24,CurrentNode:vgg_16/conv5/conv5_2/weights\n", "INFO:root:TotalNum:86,TraslatedNum:25,CurrentNode:vgg_16/conv5/conv5_2/biases\n", "INFO:root:TotalNum:86,TraslatedNum:26,CurrentNode:vgg_16/conv5/conv5_3/weights\n", "INFO:root:TotalNum:86,TraslatedNum:27,CurrentNode:vgg_16/conv5/conv5_3/biases\n", "INFO:root:TotalNum:86,TraslatedNum:28,CurrentNode:vgg_16/fc6/weights\n", "INFO:root:TotalNum:86,TraslatedNum:29,CurrentNode:vgg_16/fc6/biases\n", "INFO:root:TotalNum:86,TraslatedNum:30,CurrentNode:vgg_16/fc7/weights\n", "INFO:root:TotalNum:86,TraslatedNum:31,CurrentNode:vgg_16/fc7/biases\n", "INFO:root:TotalNum:86,TraslatedNum:32,CurrentNode:vgg_16/fc8/weights\n", "INFO:root:TotalNum:86,TraslatedNum:33,CurrentNode:vgg_16/fc8/biases\n", "INFO:root:TotalNum:86,TraslatedNum:34,CurrentNode:vgg_16/conv1/conv1_1/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:35,CurrentNode:vgg_16/conv1/conv1_1/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:36,CurrentNode:vgg_16/conv1/conv1_1/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:37,CurrentNode:vgg_16/conv1/conv1_2/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:38,CurrentNode:vgg_16/conv1/conv1_2/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:39,CurrentNode:vgg_16/conv1/conv1_2/Relu\n", 
"INFO:root:TotalNum:86,TraslatedNum:40,CurrentNode:vgg_16/pool1/MaxPool\n", "INFO:root:TotalNum:86,TraslatedNum:41,CurrentNode:vgg_16/conv2/conv2_1/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:42,CurrentNode:vgg_16/conv2/conv2_1/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:43,CurrentNode:vgg_16/conv2/conv2_1/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:44,CurrentNode:vgg_16/conv2/conv2_2/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:45,CurrentNode:vgg_16/conv2/conv2_2/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:46,CurrentNode:vgg_16/conv2/conv2_2/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:47,CurrentNode:vgg_16/pool2/MaxPool\n", "INFO:root:TotalNum:86,TraslatedNum:48,CurrentNode:vgg_16/conv3/conv3_1/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:49,CurrentNode:vgg_16/conv3/conv3_1/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:50,CurrentNode:vgg_16/conv3/conv3_1/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:51,CurrentNode:vgg_16/conv3/conv3_2/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:52,CurrentNode:vgg_16/conv3/conv3_2/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:53,CurrentNode:vgg_16/conv3/conv3_2/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:54,CurrentNode:vgg_16/conv3/conv3_3/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:55,CurrentNode:vgg_16/conv3/conv3_3/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:56,CurrentNode:vgg_16/conv3/conv3_3/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:57,CurrentNode:vgg_16/pool3/MaxPool\n", "INFO:root:TotalNum:86,TraslatedNum:58,CurrentNode:vgg_16/conv4/conv4_1/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:59,CurrentNode:vgg_16/conv4/conv4_1/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:60,CurrentNode:vgg_16/conv4/conv4_1/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:61,CurrentNode:vgg_16/conv4/conv4_2/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:62,CurrentNode:vgg_16/conv4/conv4_2/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:63,CurrentNode:vgg_16/conv4/conv4_2/Relu\n", 
"INFO:root:TotalNum:86,TraslatedNum:64,CurrentNode:vgg_16/conv4/conv4_3/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:65,CurrentNode:vgg_16/conv4/conv4_3/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:66,CurrentNode:vgg_16/conv4/conv4_3/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:67,CurrentNode:vgg_16/pool4/MaxPool\n", "INFO:root:TotalNum:86,TraslatedNum:68,CurrentNode:vgg_16/conv5/conv5_1/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:69,CurrentNode:vgg_16/conv5/conv5_1/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:70,CurrentNode:vgg_16/conv5/conv5_1/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:71,CurrentNode:vgg_16/conv5/conv5_2/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:72,CurrentNode:vgg_16/conv5/conv5_2/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:73,CurrentNode:vgg_16/conv5/conv5_2/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:74,CurrentNode:vgg_16/conv5/conv5_3/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:75,CurrentNode:vgg_16/conv5/conv5_3/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:76,CurrentNode:vgg_16/conv5/conv5_3/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:77,CurrentNode:vgg_16/pool5/MaxPool\n", "INFO:root:TotalNum:86,TraslatedNum:78,CurrentNode:vgg_16/fc6/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:79,CurrentNode:vgg_16/fc6/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:80,CurrentNode:vgg_16/fc6/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:81,CurrentNode:vgg_16/fc7/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:82,CurrentNode:vgg_16/fc7/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:83,CurrentNode:vgg_16/fc7/Relu\n", "INFO:root:TotalNum:86,TraslatedNum:84,CurrentNode:vgg_16/fc8/Conv2D\n", "INFO:root:TotalNum:86,TraslatedNum:85,CurrentNode:vgg_16/fc8/BiasAdd\n", "INFO:root:TotalNum:86,TraslatedNum:86,CurrentNode:vgg_16/fc8/squeezed\n", "INFO:root:Model translated!\n" ] } ], "source": [ "import tf2fluid.convert as convert\n", "import argparse\n", "parser = convert._get_parser()\n", "parser.meta_file = \"checkpoint/model.meta\"\n", 
"parser.ckpt_dir = \"checkpoint\"\n", "parser.in_nodes = [\"inputs\"]\n", "parser.input_shape = [\"None,224,224,3\"]\n", "parser.output_nodes = [\"vgg_16/fc8/squeezed\"]\n", "parser.use_cuda = \"True\"\n", "parser.input_format = \"NHWC\"\n", "parser.save_dir = \"paddle_model\"\n", "\n", "convert.run(parser)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 加载转换后的PaddlePaddle模型,并进行预测\n", "需要注意的是,转换后的PaddlePaddle CV模型**输入格式为NCHW**\n", "\n", "**注意:下面代码用于运行转换后的PaddlePaddle模型,并与TensorFlow计算结果对比diff,因此依赖PaddlePaddle**" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-03-15T05:51:40.544737Z", "start_time": "2019-03-15T05:51:27.857863Z" } }, "outputs": [], "source": [ "import numpy\n", "import tf2fluid.model_loader as ml\n", "\n", "model = ml.ModelLoader(\"paddle_model\", use_cuda=False)\n", "\n", "numpy.random.seed(13)\n", "data = numpy.random.rand(5, 224, 224, 3).astype(\"float32\")\n", "# NHWC -> NCHW\n", "data = numpy.transpose(data, (0, 3, 1, 2))\n", "\n", "results = model.inference(feed_dict={model.inputs[0]:data})\n", "\n", "numpy.save(\"paddle.npy\", numpy.array(results))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 对比转换前后模型之前的预测结果diff" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-03-15T05:52:02.126718Z", "start_time": "2019-03-15T05:52:02.115849Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.67572e-06\n" ] } ], "source": [ "import numpy\n", "paddle_result = numpy.load(\"paddle.npy\")\n", "tensorflow_result = numpy.load(\"tensorflow.npy\")\n", "diff = numpy.fabs(paddle_result - tensorflow_result)\n", "print(numpy.max(diff))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 需要注意的点\n", "1. 转换后的模型需要注意输入格式,PaddlePaddle中输入格式需为NCHW格式 \n", "2. 此例中不涉及到输入中间层,如卷积层的输出,需要了解的是PaddlePaddle中的卷积层输出,卷积核的`shape`与Tensorflow有差异 \n", "3. 
After converting the model, check the diff between the models before and after conversion. In this example, the maximum diff measured (about 6.7e-06) meets the conversion requirements. " ] } ], "metadata": { "kernelspec": { "display_name": "tensorflow", "language": "python", "name": "tensorflow" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.6" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }