tensorflow2paddle.ipynb 6.0 KB
Notebook
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# X2Paddle快速上手——TensorFlow迁移至PaddlePaddle\n",
    "***X2Paddle简介***:X2Paddle支持将Caffe/TensorFlow/ONNX/PyTorch深度学习框架训练得到的模型,迁移至PaddlePaddle模型。   \n",
    "***X2Paddle代码GitHub链接***:[https://github.com/PaddlePaddle/X2Paddle](https://github.com/PaddlePaddle/X2Paddle)  \n",
    "***【注意】***前往GitHub给[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)点击Star,关注项目,即可随时了解X2Paddle的最新进展。  \n",
    "本教程用于帮助用户学习将TensorFlow训练后的预测模型迁移至PaddlePaddle框架,以TensorFlow版本的[MobileNetV1](https://github.com/tensorflow/models/tree/master/research/slim)为例进行详细介绍。  \n",
    "\n",
    "## 安装及准备\n",
    "### 1. 安装X2Paddle\n",
    "***方式一:(推荐)***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! git clone https://github.com/PaddlePaddle/X2Paddle.git\n",
    "! cd X2Paddle\n",
    "! git checkout develop\n",
    "! python setup.py install"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***方式二:***"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install x2paddle==1.0.1 --index https://pypi.Python.org/simple/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 安装TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install tensorflow==1.14.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 安装PaddlePaddle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install paddlepaddle==2.0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型迁移\n",
    "### 1. 获取MobileNetV1的FrozenModel\n",
    "由于X2Paddle只支持TensorFlow中FrozenModel的转换,如果为纯checkpoint模型,需要参考参考X2Paddle官方[文档](https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/user_guides/export_tf_model.md),将其转换为FrozenModel,本示例中提供的模型为FrozenModel,所以无需转换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! wget http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz\n",
    "! tar zxvf mobilenet_v1_0.25_128.tgz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 转换\n",
    "需要传入的参数如下:\n",
    "> --framework (-f):源模型类型,此处设置为tensorflow。  \n",
    "> --save_dir (-s):指定转换后的模型保存目录路径。  \n",
    "> --model (-m):指定tensorflow的pb模型。  \n",
    "> --paddle_type (-pt):指定转换为动态图代码(dygraph)或者静态图代码(static),默认为dygraph。  \n",
    "\n",
    "***方式一:***生成静态图代码,并保存成静态图预测模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! x2paddle -f tensorflow -m ./mobilenet_v1_0.25_128_frozen.pb -s pd_model_static -pt static"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***方式二:***生成动态图代码,并保存成静态图预测模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! x2paddle -f tensorflow -m ./mobilenet_v1_0.25_128_frozen.pb  -s pd_model_dygraph -pt dygraph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用方式一转换得到的PaddlePaddle预测模型进行预测:  \n",
    "(1)下载ImageNet类别文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! wget https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(2)预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构造输入\n",
    "import cv2\n",
    "import numpy as np\n",
    "img = cv2.imread(\"dog_tf.png\").astype(\"float32\") / 255.0\n",
    "img = np.expand_dims(img, 0)\n",
    "img -= 0.5\n",
    "img *= 2.0\n",
    "    \n",
    "# 进行预测\n",
    "import paddle\n",
    "import numpy as np\n",
    "paddle.enable_static()\n",
    "exe = paddle.static.Executor(paddle.CPUPlace())\n",
    "[prog, inputs, outputs] = paddle.static.load_inference_model(path_prefix=\"pd_model_static/inference_model\", \n",
    "                                                             executor=exe, \n",
    "                                                             model_filename=\"model.pdmodel\",\n",
    "                                                             params_filename=\"model.pdiparams\")\n",
    "result = exe.run(prog, feed={inputs[0]: img}, fetch_list=outputs)\n",
    "max_index = np.argmax(result[0])\n",
    "with open('imagenet_classes.txt') as f:\n",
    "    classes = [line.strip() for line in f.readlines()]\n",
    "print(\"The category of dog.jpg is: {}\".format(classes[max_index]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "all",
   "language": "python",
   "name": "all"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}