{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# X2Paddle快速上手——TensorFlow迁移至PaddlePaddle\n", "***X2Paddle简介***:X2Paddle支持将Caffe/TensorFlow/ONNX/PyTorch深度学习框架训练得到的模型,迁移至PaddlePaddle模型。 \n", "***X2Paddle代码GitHub链接***:[https://github.com/PaddlePaddle/X2Paddle](https://github.com/PaddlePaddle/X2Paddle) \n", "***【注意】***前往GitHub给[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)点击Star,关注项目,即可随时了解X2Paddle的最新进展。 \n", "本教程用于帮助用户学习将TensorFlow训练后的预测模型迁移至PaddlePaddle框架,以TensorFlow版本的[MobileNetV1](https://github.com/tensorflow/models/tree/master/research/slim)为例进行详细介绍。 \n", "\n", "## 安装及准备\n", "### 1. 安装X2Paddle\n", "***方式一:(推荐)***" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! git clone https://github.com/PaddlePaddle/X2Paddle.git\n", "! cd X2Paddle\n", "! git checkout develop\n", "! python setup.py install" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "***方式二:***" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! pip install x2paddle==1.0.1 --index https://pypi.Python.org/simple/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 安装TensorFlow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! pip install tensorflow==1.14.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. 安装PaddlePaddle" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! pip install paddlepaddle==2.0.1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 模型迁移\n", "### 1. 获取MobileNetV1的FrozenModel\n", "由于X2Paddle只支持TensorFlow中FrozenModel的转换,如果为纯checkpoint模型,需要参考参考X2Paddle官方[文档](https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/user_guides/export_tf_model.md),将其转换为FrozenModel,本示例中提供的模型为FrozenModel,所以无需转换。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! wget http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz\n", "! tar zxvf mobilenet_v1_0.25_128.tgz" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 转换\n", "需要传入的参数如下:\n", "> --framework (-f):源模型类型,此处设置为tensorflow。 \n", "> --save_dir (-s):指定转换后的模型保存目录路径。 \n", "> --model (-m):指定tensorflow的pb模型。 \n", "> --paddle_type (-pt):指定转换为动态图代码(dygraph)或者静态图代码(static),默认为dygraph。 \n", "\n", "***方式一:***生成静态图代码,并保存成静态图预测模型" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! x2paddle -f tensorflow -m ./mobilenet_v1_0.25_128_frozen.pb -s pd_model_static -pt static" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "***方式二:***生成动态图代码,并保存成静态图预测模型\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! x2paddle -f tensorflow -m ./mobilenet_v1_0.25_128_frozen.pb -s pd_model_dygraph -pt dygraph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "使用方式一转换得到的PaddlePaddle预测模型进行预测: \n", "(1)下载ImageNet类别文件" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! wget https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(2)预测" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 构造输入\n", "import cv2\n", "import numpy as np\n", "img = cv2.imread(\"dog_tf.png\").astype(\"float32\") / 255.0\n", "img = np.expand_dims(img, 0)\n", "img -= 0.5\n", "img *= 2.0\n", " \n", "# 进行预测\n", "import paddle\n", "import numpy as np\n", "paddle.enable_static()\n", "exe = paddle.static.Executor(paddle.CPUPlace())\n", "[prog, inputs, outputs] = paddle.static.load_inference_model(path_prefix=\"pd_model_static/inference_model\", \n", " executor=exe, \n", " model_filename=\"model.pdmodel\",\n", " params_filename=\"model.pdiparams\")\n", "result = exe.run(prog, feed={inputs[0]: img}, fetch_list=outputs)\n", "max_index = np.argmax(result[0])\n", "with open('imagenet_classes.txt') as f:\n", " classes = [line.strip() for line in f.readlines()]\n", "print(\"The category of dog.jpg is: {}\".format(classes[max_index]))" ] } ], "metadata": { "kernelspec": { "display_name": "all", "language": "python", "name": "all" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }