# 通过使用 Flask 的 REST API 在 Python 中部署 PyTorch > 原文: **作者**: [Avinash Sajjanshetty](https://avi.im) 在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开用于模型推理的 REST API。 特别是,我们将部署预训练的 DenseNet 121 模型来检测图像。 小费 此处使用的所有代码均以 MIT 许可发布,可在 [Github](https://github.com/avinassh/pytorch-flask-api) 上找到。 这是在生产中部署 PyTorch 模型的系列教程中的第一篇。 到目前为止,以这种方式使用 Flask 是开始为 PyTorch 模型提供服务的最简单方法,但不适用于具有高性能要求的用例。 为了那个原因: > * 如果您已经熟悉 TorchScript,则可以直接进入我们的[通过 C++ 加载 TorchScript 模型](https://pytorch.org/tutorials/advanced/cpp_export.html)的教程。 > * 如果您首先需要在 TorchScript 上进行复习,请查看我们的 [TorchScript 入门](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)教程。 ## API 定义 我们将首先定义 API 端点,请求和响应类型。 我们的 API 端点将位于`/predict`,它通过包含图片的`file`参数接受 HTTP POST 请求。 响应将是包含预测的 JSON 响应: ```py {"class_id": "n02124075", "class_name": "Egyptian_cat"} ``` ## 依赖项 通过运行以下命令来安装所需的依赖项: ```py $ pip install Flask==1.0.3 torchvision-0.3.0 ``` ## 简单的 Web 服务器 以下是一个简单的网络服务器,摘自 Flask 的文档 ```py from flask import Flask app = Flask(__name__) @app.route('/') def hello(): return 'Hello World!' ``` 将以上代码段保存在名为`app.py`的文件中,您现在可以通过输入以下内容来运行 Flask 开发服务器: ```py $ FLASK_ENV=development FLASK_APP=app.py flask run ``` 当您在网络浏览器中访问`http://localhost:5000/`时,您会看到`Hello World!`文字 我们将对上面的代码片段进行一些更改,以使其适合我们的 API 定义。 首先,我们将方法重命名为`predict`。 我们将端点路径更新为`/predict`。 由于图像文件将通过 HTTP POST 请求发送,因此我们将对其进行更新,使其也仅接受 POST 请求: ```py @app.route('/predict', methods=['POST']) def predict(): return 'Hello World!' ``` 我们还将更改响应类型,以使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。 更新后的`app.py`文件现在为: ```py from flask import Flask, jsonify app = Flask(__name__) @app.route('/predict', methods=['POST']) def predict(): return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'}) ``` ## 推断 在下一部分中,我们将重点介绍编写推理代码。 这将涉及两部分,第一部分是准备图像,以便可以将其馈送到 DenseNet;第二部分,我们将编写代码以从模型中获取实际的预测。 ### 准备图像 DenseNet 模型要求图像为尺寸为`224 x 224`的 3 通道 RGB 图像。我们还将使用所需的均值和标准差值对图像张量进行归一化。 您可以在上阅读有关它的更多信息。 我们将使用`torchvision`库中的`transforms`并建立一个转换管道,该转换管道可根据需要转换图像。 [您可以这里阅读有关转换的更多信息](https://pytorch.org/docs/stable/torchvision/transforms.html)。 ```py import io import torchvision.transforms as transforms from PIL import Image def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0) ``` 上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。 要测试上述方法,请以字节模式读取图像文件(首先将`../_static/img/sample_file.jpeg`替换为计算机上文件的实际路径),然后查看是否取回张量: ```py with open("../_static/img/sample_file.jpeg", 'rb') as f: image_bytes = f.read() tensor = transform_image(image_bytes=image_bytes) print(tensor) ``` 出: ```py tensor([[[[ 0.4508, 0.4166, 0.3994, ..., -1.3473, -1.3302, -1.3473], [ 0.5364, 0.4851, 0.4508, ..., -1.2959, -1.3130, -1.3302], [ 0.7077, 0.6392, 0.6049, ..., -1.2959, -1.3302, -1.3644], ..., [ 1.3755, 1.3927, 1.4098, ..., 1.1700, 1.3584, 1.6667], [ 1.8893, 1.7694, 1.4440, ..., 1.2899, 1.4783, 1.5468], [ 1.6324, 1.8379, 1.8379, ..., 1.4783, 1.7352, 1.4612]], [[ 0.5728, 0.5378, 0.5203, ..., -1.3704, -1.3529, -1.3529], [ 0.6604, 0.6078, 0.5728, ..., -1.3004, -1.3179, -1.3354], [ 0.8529, 0.7654, 0.7304, ..., -1.3004, -1.3354, -1.3704], ..., [ 1.4657, 1.4657, 1.4832, ..., 1.3256, 1.5357, 1.8508], [ 2.0084, 1.8683, 1.5182, ..., 1.4657, 1.6583, 1.7283], [ 1.7458, 1.9384, 1.9209, ..., 1.6583, 1.9209, 1.6408]], [[ 0.7228, 0.6879, 0.6531, ..., -1.6476, -1.6302, -1.6476], [ 0.8099, 0.7576, 0.7228, ..., -1.6476, -1.6476, -1.6650], [ 1.0017, 0.9145, 0.8797, ..., -1.6476, -1.6650, -1.6999], ..., [ 1.6291, 1.6291, 1.6465, ..., 1.6291, 1.8208, 2.1346], [ 2.1868, 2.0300, 1.6814, ..., 1.7685, 1.9428, 2.0125], [ 1.9254, 2.0997, 2.0823, ..., 1.9428, 2.2043, 1.9080]]]]) ``` ### 预测 现在将使用预训练的 DenseNet 121 模型来预测图像类别。 我们将使用`torchvision`库中的一个,加载模型并进行推断。 在此示例中,我们将使用预训练模型,但您可以对自己的模型使用相同的方法。 在此[教程](../beginner/saving_loading_models.html)中查看有关加载模型的更多信息。 ```py from torchvision import models # Make sure to pass `pretrained` as `True` to use the pretrained weights: model = models.densenet121(pretrained=True) # Since we are using our model only for inference, switch to `eval` mode: model.eval() def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) return y_hat ``` 张量`y_hat`将包含预测的类 ID 的索引。 但是,我们需要一个人类可读的类名。 为此,我们需要一个类 ID 来进行名称映射。 将[这个文件](https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json)下载为`imagenet_class_index.json`,并记住它的保存位置(或者,如果您按照本教程中的确切步骤操作,请将其保存在`tutorials/_static`中)。 此文件包含 ImageNet 类 ID 到 ImageNet 类名称的映射。 我们将加载此 JSON 文件并获取预测索引的类名称。 ```py import json imagenet_class_index = json.load(open('../_static/imagenet_class_index.json')) def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx] ``` 在使用`imagenet_class_index`字典之前,首先我们将张量值转换为字符串值,因为`imagenet_class_index`字典中的键是字符串。 我们将测试上述方法: ```py with open("../_static/img/sample_file.jpeg", 'rb') as f: image_bytes = f.read() print(get_prediction(image_bytes=image_bytes)) ``` 出: ```py ['n02124075', 'Egyptian_cat'] ``` 您应该得到如下响应: ```py ['n02124075', 'Egyptian_cat'] ``` 数组中的第一项是 ImageNet 类 ID,第二项是人类可读的名称。 注意 您是否注意到`model`变量不属于`get_prediction`方法? 还是为什么模型是全局变量? 就内存和计算而言,加载模型可能是一项昂贵的操作。 如果我们以`get_prediction`方法加载模型,则每次调用该方法时都会不必要地加载该模型。 由于我们正在构建一个 Web 服务器,因此每秒可能有成千上万的请求,因此我们不应该浪费时间为每个推断重复加载模型。 因此,我们仅将模型加载到内存中一次。 在生产系统中,必须高效使用计算以能够大规模处理请求,因此通常应在处理请求之前加载模型。 ## 将模型集成到我们的 API 服务器中 在最后一部分中,我们将模型添加到 Flask API 服务器中。 由于我们的 API 服务器应该获取图像文件,因此我们将更新`predict`方法以从请求中读取文件: ```py from flask import request @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': # we will get the file from the request file = request.files['file'] # convert that to bytes img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name}) ``` `app.py`文件现在完成。 以下是完整版本; 将路径替换为保存文件的路径,它应运行: ```py import io import json from torchvision import models import torchvision.transforms as transforms from PIL import Image from flask import Flask, jsonify, request app = Flask(__name__) imagenet_class_index = json.load(open('/imagenet_class_index.json')) model = models.densenet121(pretrained=True) model.eval() def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0) def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx] @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': file = request.files['file'] img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name}) if __name__ == '__main__': app.run() ``` 让我们测试一下我们的网络服务器! 跑: ```py $ FLASK_ENV=development FLASK_APP=app.py flask run ``` 我们可以使用[`requests`](https://pypi.org/project/requests/)库向我们的应用发送 POST 请求: ```py import requests resp = requests.post("http://localhost:5000/predict", files={"file": open('/cat.jpg','rb')}) ``` 现在打印`resp.json()`将显示以下内容: ```py {"class_id": "n02124075", "class_name": "Egyptian_cat"} ``` ## 后续步骤 我们编写的服务器非常琐碎,可能无法完成生产应用所需的一切。 因此,您可以采取一些措施来改善它: * 端点`/predict`假定请求中始终会有一个图像文件。 这可能并不适用于所有请求。 我们的用户可能发送带有其他参数的图像,或者根本不发送任何图像。 * 用户也可以发送非图像类型的文件。 由于我们没有处理错误,因此这将破坏我们的服务器。 添加显式的错误处理路径将引发异常,这将使我们能够更好地处理错误的输入 * 即使模型可以识别大量类别的图像,也可能无法识别所有图像。 增强实现以处理模型无法识别图像中的任何情况的情况。 * 我们在开发模式下运行 Flask 服务器,该服务器不适合在生产中进行部署。 您可以查看[本教程](https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/),以便在生产环境中部署 Flask 服务器。 * 您还可以通过创建一个带有表单的页面来添加 UI,该表单可以拍摄图像并显示预测。 查看类似项目的[演示](https://pytorch-imagenet.herokuapp.com/)及其[源代码](https://github.com/avinassh/pytorch-flask-api-heroku)。 * 在本教程中,我们仅展示了如何构建可以一次返回单个图像预测的服务。 我们可以修改服务以能够一次返回多个图像的预测。 此外,[service-streamer](https://github.com/ShannonAI/service-streamer) 库自动将对服务的请求排队,并将请求采样到微型批量中,这些微型批量可输入模型中。 您可以查看[本教程](https://github.com/ShannonAI/service-streamer/wiki/Vision-Recognition-Service-with-Flask-and-service-streamer)。 * 最后,我们鼓励您在页面顶部查看链接到的其他 PyTorch 模型部署教程。 **脚本的总运行时间**:(0 分钟 1.232 秒) [下载 Python 源码:`flask_rest_api_tutorial.py`](../_downloads/146c514e84d7e33f2a302bcc3ae793cb/flask_rest_api_tutorial.py) [下载 Jupyter 笔记本:`flask_rest_api_tutorial.ipynb`](../_downloads/6c042f3d39855d2a2de414758e5f9836/flask_rest_api_tutorial.ipynb) [由 Sphinx 画廊](https://sphinx-gallery.readthedocs.io)生成的画廊