diff --git a/tensorflow2fluid/vgg_translate_tutorial.ipynb b/tensorflow2fluid/vgg_translate_tutorial.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..7946cdb56ebb2bff132fb9c5d308443b7f5f4974 --- /dev/null +++ b/tensorflow2fluid/vgg_translate_tutorial.ipynb @@ -0,0 +1,369 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tutorial : tensorflow2fluid转换VGG_16模型\n", + "\n", + "VGG_16是CV领域的一个经典模型,本文档以tensorflow/models下的[VGG_16](https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py)为例,展示如何将TensorFlow训练好的模型转换为PaddlePaddle模型。 \n", + "### 下载预训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Download percentage 100.00%" + ] + } + ], + "source": [ + "import urllib\n", + "import sys\n", + "def schedule(a, b, c):\n", + " per = 100.0 * a * b / c\n", + " per = int(per)\n", + " sys.stderr.write(\"\\rDownload percentage %.2f%%\" % per)\n", + " sys.stderr.flush()\n", + "\n", + "url = \"http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz\"\n", + "fetch = urllib.urlretrieve(url, \"./vgg_16.tar.gz\", schedule)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 解压下载的压缩文件" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import tarfile\n", + "with tarfile.open(\"./vgg_16.tar.gz\", \"r:gz\") as f:\n", + " file_names = f.getnames()\n", + " for file_name in file_names:\n", + " f.extract(file_name, \"./\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 保存模型为checkpoint格式\n", + "\n", + "tensorflow2fluid目前支持checkpoint格式的模型或者是将网络结构和参数序列化的pb格式模型,上面下载的`vgg_16.ckpt`仅仅存储了模型参数,因此我们需要重新加载参数,并将网络结构和参数一起保存为checkpoint模型" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Restoring parameters from vgg_16.ckpt\n" + ] + } + ], + "source": [ + "import tensorflow.contrib.slim as slim\n", + "from tensorflow.contrib.slim.nets import vgg\n", + "import tensorflow as tf\n", + "import numpy\n", + "\n", + "with tf.Session() as sess:\n", + " inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name=\"inputs\")\n", + " logits, endpoint = vgg.vgg_16(inputs, num_classes=1000, is_training=False)\n", + " load_model = slim.assign_from_checkpoint_fn(\"vgg_16.ckpt\", slim.get_model_variables(\"vgg_16\"))\n", + " load_model(sess)\n", + " \n", + " numpy.random.seed(13)\n", + " data = numpy.random.rand(5, 224, 224, 3)\n", + " input_tensor = sess.graph.get_tensor_by_name(\"inputs:0\")\n", + " output_tensor = sess.graph.get_tensor_by_name(\"vgg_16/fc8/squeezed:0\")\n", + " result = sess.run([output_tensor], {input_tensor:data})\n", + " numpy.save(\"tensorflow.npy\", numpy.array(result))\n", + " \n", + " saver = tf.train.Saver()\n", + " saver.save(sess, \"./checkpoint/model\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 将模型转换为PaddlePaddle模型\n", + "\n", + "注意:部分OP在转换时,需要将参数写入文件;或者是运行tensorflow模型进行infer,获取tensor值。两种情况下均会消耗一定的时间用于IO或计算,对于后一种情况,建议转换模型时将`use_cuda`参数设为`True`,加快infer速度" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading tensorflow model...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Restoring parameters from checkpoint/model\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Restoring parameters from checkpoint/model\n", + "INFO:root:Tensorflow model loaded!\n", + "INFO:root:TotalNum:86,TraslatedNum:1,CurrentNode:inputs\n", + "INFO:root:TotalNum:86,TraslatedNum:2,CurrentNode:vgg_16/conv1/conv1_1/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:3,CurrentNode:vgg_16/conv1/conv1_1/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:4,CurrentNode:vgg_16/conv1/conv1_2/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:5,CurrentNode:vgg_16/conv1/conv1_2/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:6,CurrentNode:vgg_16/conv2/conv2_1/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:7,CurrentNode:vgg_16/conv2/conv2_1/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:8,CurrentNode:vgg_16/conv2/conv2_2/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:9,CurrentNode:vgg_16/conv2/conv2_2/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:10,CurrentNode:vgg_16/conv3/conv3_1/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:11,CurrentNode:vgg_16/conv3/conv3_1/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:12,CurrentNode:vgg_16/conv3/conv3_2/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:13,CurrentNode:vgg_16/conv3/conv3_2/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:14,CurrentNode:vgg_16/conv3/conv3_3/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:15,CurrentNode:vgg_16/conv3/conv3_3/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:16,CurrentNode:vgg_16/conv4/conv4_1/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:17,CurrentNode:vgg_16/conv4/conv4_1/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:18,CurrentNode:vgg_16/conv4/conv4_2/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:19,CurrentNode:vgg_16/conv4/conv4_2/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:20,CurrentNode:vgg_16/conv4/conv4_3/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:21,CurrentNode:vgg_16/conv4/conv4_3/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:22,CurrentNode:vgg_16/conv5/conv5_1/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:23,CurrentNode:vgg_16/conv5/conv5_1/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:24,CurrentNode:vgg_16/conv5/conv5_2/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:25,CurrentNode:vgg_16/conv5/conv5_2/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:26,CurrentNode:vgg_16/conv5/conv5_3/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:27,CurrentNode:vgg_16/conv5/conv5_3/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:28,CurrentNode:vgg_16/fc6/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:29,CurrentNode:vgg_16/fc6/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:30,CurrentNode:vgg_16/fc7/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:31,CurrentNode:vgg_16/fc7/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:32,CurrentNode:vgg_16/fc8/weights\n", + "INFO:root:TotalNum:86,TraslatedNum:33,CurrentNode:vgg_16/fc8/biases\n", + "INFO:root:TotalNum:86,TraslatedNum:34,CurrentNode:vgg_16/conv1/conv1_1/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:35,CurrentNode:vgg_16/conv1/conv1_1/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:36,CurrentNode:vgg_16/conv1/conv1_1/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:37,CurrentNode:vgg_16/conv1/conv1_2/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:38,CurrentNode:vgg_16/conv1/conv1_2/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:39,CurrentNode:vgg_16/conv1/conv1_2/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:40,CurrentNode:vgg_16/pool1/MaxPool\n", + "INFO:root:TotalNum:86,TraslatedNum:41,CurrentNode:vgg_16/conv2/conv2_1/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:42,CurrentNode:vgg_16/conv2/conv2_1/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:43,CurrentNode:vgg_16/conv2/conv2_1/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:44,CurrentNode:vgg_16/conv2/conv2_2/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:45,CurrentNode:vgg_16/conv2/conv2_2/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:46,CurrentNode:vgg_16/conv2/conv2_2/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:47,CurrentNode:vgg_16/pool2/MaxPool\n", + "INFO:root:TotalNum:86,TraslatedNum:48,CurrentNode:vgg_16/conv3/conv3_1/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:49,CurrentNode:vgg_16/conv3/conv3_1/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:50,CurrentNode:vgg_16/conv3/conv3_1/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:51,CurrentNode:vgg_16/conv3/conv3_2/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:52,CurrentNode:vgg_16/conv3/conv3_2/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:53,CurrentNode:vgg_16/conv3/conv3_2/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:54,CurrentNode:vgg_16/conv3/conv3_3/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:55,CurrentNode:vgg_16/conv3/conv3_3/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:56,CurrentNode:vgg_16/conv3/conv3_3/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:57,CurrentNode:vgg_16/pool3/MaxPool\n", + "INFO:root:TotalNum:86,TraslatedNum:58,CurrentNode:vgg_16/conv4/conv4_1/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:59,CurrentNode:vgg_16/conv4/conv4_1/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:60,CurrentNode:vgg_16/conv4/conv4_1/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:61,CurrentNode:vgg_16/conv4/conv4_2/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:62,CurrentNode:vgg_16/conv4/conv4_2/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:63,CurrentNode:vgg_16/conv4/conv4_2/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:64,CurrentNode:vgg_16/conv4/conv4_3/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:65,CurrentNode:vgg_16/conv4/conv4_3/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:66,CurrentNode:vgg_16/conv4/conv4_3/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:67,CurrentNode:vgg_16/pool4/MaxPool\n", + "INFO:root:TotalNum:86,TraslatedNum:68,CurrentNode:vgg_16/conv5/conv5_1/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:69,CurrentNode:vgg_16/conv5/conv5_1/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:70,CurrentNode:vgg_16/conv5/conv5_1/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:71,CurrentNode:vgg_16/conv5/conv5_2/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:72,CurrentNode:vgg_16/conv5/conv5_2/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:73,CurrentNode:vgg_16/conv5/conv5_2/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:74,CurrentNode:vgg_16/conv5/conv5_3/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:75,CurrentNode:vgg_16/conv5/conv5_3/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:76,CurrentNode:vgg_16/conv5/conv5_3/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:77,CurrentNode:vgg_16/pool5/MaxPool\n", + "INFO:root:TotalNum:86,TraslatedNum:78,CurrentNode:vgg_16/fc6/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:79,CurrentNode:vgg_16/fc6/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:80,CurrentNode:vgg_16/fc6/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:81,CurrentNode:vgg_16/fc7/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:82,CurrentNode:vgg_16/fc7/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:83,CurrentNode:vgg_16/fc7/Relu\n", + "INFO:root:TotalNum:86,TraslatedNum:84,CurrentNode:vgg_16/fc8/Conv2D\n", + "INFO:root:TotalNum:86,TraslatedNum:85,CurrentNode:vgg_16/fc8/BiasAdd\n", + "INFO:root:TotalNum:86,TraslatedNum:86,CurrentNode:vgg_16/fc8/squeezed\n", + "INFO:root:Model translated!\n" + ] + } + ], + "source": [ + "import tf2fluid.convert as convert\n", + "import argparse\n", + "parser = convert._get_parser()\n", + "parser.meta_file = \"checkpoint/model.meta\"\n", + "parser.ckpt_dir = \"checkpoint\"\n", + "parser.in_nodes = [\"inputs\"]\n", + "parser.input_shape = [\"None,224,224,3\"]\n", + "parser.output_nodes = [\"vgg_16/fc8/squeezed\"]\n", + "parser.use_cuda = \"True\"\n", + "parser.input_format = \"NHWC\"\n", + "parser.save_dir = \"paddle_model\"\n", + "\n", + "convert.run(parser)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 加载转换后的PaddlePaddle模型,并进行预测\n", + "**需要注意,转换后的PaddlePaddle CV模型输入格式为NCHW**" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2019-03-14T10:18:46.124339Z", + "start_time": "2019-03-14T10:18:40.858372Z" + } + }, + "outputs": [], + "source": [ + "import numpy\n", + "import tf2fluid.model_loader as ml\n", + "\n", + "model = ml.ModelLoader(\"paddle_model\", use_cuda=True)\n", + "\n", + "numpy.random.seed(13)\n", + "data = numpy.random.rand(5, 224, 224, 3).astype(\"float32\")\n", + "# NHWC -> NCHW\n", + "data = numpy.transpose(data, (0, 3, 1, 2))\n", + "\n", + "results = model.inference(feed_dict={model.inputs[0]:data})\n", + "\n", + "numpy.save(\"paddle.npy\", numpy.array(results))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 对比转换前后模型之前的预测结果diff" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2019-03-14T10:20:13.611132Z", + "start_time": "2019-03-14T10:20:13.598874Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.33786e-06\n" + ] + } + ], + "source": [ + "import numpy\n", + "paddle_result = numpy.load(\"paddle.npy\")\n", + "tensorflow_result = numpy.load(\"tensorflow.npy\")\n", + "diff = numpy.fabs(paddle_result - tensorflow_result)\n", + "print(numpy.max(diff))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:CPU-Paddle]", + "language": "python", + "name": "conda-env-CPU-Paddle-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.15" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}