37.md 13.5 KB
Newer Older
W
wizardforcel 已提交
1
# 通过使用 Flask 的 REST API 在 Python 中部署 PyTorch
W
wizardforcel 已提交
2

W
wizardforcel 已提交
3
> 原文:<https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html>
W
wizardforcel 已提交
4 5 6 7 8 9 10 11 12 13 14

**作者**[Avinash Sajjanshetty](https://avi.im)

在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开用于模型推理的 REST API。 特别是,我们将部署预训练的 DenseNet 121 模型来检测图像。

小费

此处使用的所有代码均以 MIT 许可发布,可在 [Github](https://github.com/avinassh/pytorch-flask-api) 上找到。

这是在生产中部署 PyTorch 模型的系列教程中的第一篇。 到目前为止,以这种方式使用 Flask 是开始为 PyTorch 模型提供服务的最简单方法,但不适用于具有高性能要求的用例。 为了那个原因:

W
wizardforcel 已提交
15 16
> *   如果您已经熟悉 TorchScript,则可以直接进入我们的[通过 C++ 加载 TorchScript 模型](https://pytorch.org/tutorials/advanced/cpp_export.html)的教程。
> *   如果您首先需要在 TorchScript 上进行复习,请查看我们的 [TorchScript 入门](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)教程。
W
wizardforcel 已提交
17

W
wizardforcel 已提交
18
## API 定义
W
wizardforcel 已提交
19 20 21 22 23 24 25 26

我们将首先定义 API 端点,请求和响应类型。 我们的 API 端点将位于`/predict`,它通过包含图片的`file`参数接受 HTTP POST 请求。 响应将是包含预测的 JSON 响应:

```py
{"class_id": "n02124075", "class_name": "Egyptian_cat"}

```

W
wizardforcel 已提交
27
## 依赖项
W
wizardforcel 已提交
28 29 30 31 32 33 34 35

通过运行以下命令来安装所需的依赖项:

```py
$ pip install Flask==1.0.3 torchvision-0.3.0

```

W
wizardforcel 已提交
36
## 简单的 Web 服务器
W
wizardforcel 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79

以下是一个简单的网络服务器,摘自 Flask 的文档

```py
from flask import Flask
app = Flask(__name__)

@app.route('/')
def hello():
    return 'Hello World!'

```

将以上代码段保存在名为`app.py`的文件中,您现在可以通过输入以下内容来运行 Flask 开发服务器:

```py
$ FLASK_ENV=development FLASK_APP=app.py flask run

```

当您在网络浏览器中访问`http://localhost:5000/`时,您会看到`Hello World!`文字

我们将对上面的代码片段进行一些更改,以使其适合我们的 API 定义。 首先,我们将方法重命名为`predict`。 我们将端点路径更新为`/predict`。 由于图像文件将通过 HTTP POST 请求发送,因此我们将对其进行更新,使其也仅接受 POST 请求:

```py
@app.route('/predict', methods=['POST'])
def predict():
    return 'Hello World!'

```

我们还将更改响应类型,以使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。 更新后的`app.py`文件现在为:

```py
from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

```

W
wizardforcel 已提交
80
## 推断
W
wizardforcel 已提交
81 82 83

在下一部分中,我们将重点介绍编写推理代码。 这将涉及两部分,第一部分是准备图像,以便可以将其馈送到 DenseNet;第二部分,我们将编写代码以从模型中获取实际的预测。

W
wizardforcel 已提交
84
### 准备图像
W
wizardforcel 已提交
85

W
wizardforcel 已提交
86
DenseNet 模型要求图像为尺寸为`224 x 224`的 3 通道 RGB 图像。我们还将使用所需的均值和标准差值对图像张量进行归一化。 您可以在上阅读有关它的更多信息。
W
wizardforcel 已提交
87

W
wizardforcel 已提交
88
我们将使用`torchvision`库中的`transforms`并建立一个转换管道,该转换管道可根据需要转换图像。 [您可以这里阅读有关转换的更多信息](https://pytorch.org/docs/stable/torchvision/transforms.html)
W
wizardforcel 已提交
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107

```py
import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

```

W
wizardforcel 已提交
108
上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。 要测试上述方法,请以字节模式读取图像文件(首先将`../_static/img/sample_file.jpeg`替换为计算机上文件的实际路径),然后查看是否取回张量:
W
wizardforcel 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146

```py
with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

```

出:

```py
tensor([[[[ 0.4508,  0.4166,  0.3994,  ..., -1.3473, -1.3302, -1.3473],
          [ 0.5364,  0.4851,  0.4508,  ..., -1.2959, -1.3130, -1.3302],
          [ 0.7077,  0.6392,  0.6049,  ..., -1.2959, -1.3302, -1.3644],
          ...,
          [ 1.3755,  1.3927,  1.4098,  ...,  1.1700,  1.3584,  1.6667],
          [ 1.8893,  1.7694,  1.4440,  ...,  1.2899,  1.4783,  1.5468],
          [ 1.6324,  1.8379,  1.8379,  ...,  1.4783,  1.7352,  1.4612]],

         [[ 0.5728,  0.5378,  0.5203,  ..., -1.3704, -1.3529, -1.3529],
          [ 0.6604,  0.6078,  0.5728,  ..., -1.3004, -1.3179, -1.3354],
          [ 0.8529,  0.7654,  0.7304,  ..., -1.3004, -1.3354, -1.3704],
          ...,
          [ 1.4657,  1.4657,  1.4832,  ...,  1.3256,  1.5357,  1.8508],
          [ 2.0084,  1.8683,  1.5182,  ...,  1.4657,  1.6583,  1.7283],
          [ 1.7458,  1.9384,  1.9209,  ...,  1.6583,  1.9209,  1.6408]],

         [[ 0.7228,  0.6879,  0.6531,  ..., -1.6476, -1.6302, -1.6476],
          [ 0.8099,  0.7576,  0.7228,  ..., -1.6476, -1.6476, -1.6650],
          [ 1.0017,  0.9145,  0.8797,  ..., -1.6476, -1.6650, -1.6999],
          ...,
          [ 1.6291,  1.6291,  1.6465,  ...,  1.6291,  1.8208,  2.1346],
          [ 2.1868,  2.0300,  1.6814,  ...,  1.7685,  1.9428,  2.0125],
          [ 1.9254,  2.0997,  2.0823,  ...,  1.9428,  2.2043,  1.9080]]]])

```

W
wizardforcel 已提交
147
### 预测
W
wizardforcel 已提交
148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166

现在将使用预训练的 DenseNet 121 模型来预测图像类别。 我们将使用`torchvision`库中的一个,加载模型并进行推断。 在此示例中,我们将使用预训练模型,但您可以对自己的模型使用相同的方法。 在此[教程](../beginner/saving_loading_models.html)中查看有关加载模型的更多信息。

```py
from torchvision import models

# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

```

W
wizardforcel 已提交
167
张量`y_hat`将包含预测的类 ID 的索引。 但是,我们需要一个人类可读的类名。 为此,我们需要一个类 ID 来进行名称映射。 将[这个文件](https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json)下载为`imagenet_class_index.json`,并记住它的保存位置(或者,如果您按照本教程中的确切步骤操作,请将其保存在`tutorials/_static`中)。 此文件包含 ImageNet 类 ID 到 ImageNet 类名称的映射。 我们将加载此 JSON 文件并获取预测索引的类名称。
W
wizardforcel 已提交
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211

```py
import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

```

在使用`imagenet_class_index`字典之前,首先我们将张量值转换为字符串值,因为`imagenet_class_index`字典中的键是字符串。 我们将测试上述方法:

```py
with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

```

出:

```py
['n02124075', 'Egyptian_cat']

```

您应该得到如下响应:

```py
['n02124075', 'Egyptian_cat']

```

数组中的第一项是 ImageNet 类 ID,第二项是人类可读的名称。

注意

您是否注意到`model`变量不属于`get_prediction`方法? 还是为什么模型是全局变量? 就内存和计算而言,加载模型可能是一项昂贵的操作。 如果我们以`get_prediction`方法加载模型,则每次调用该方法时都会不必要地加载该模型。 由于我们正在构建一个 Web 服务器,因此每秒可能有成千上万的请求,因此我们不应该浪费时间为每个推断重复加载模型。 因此,我们仅将模型加载到内存中一次。 在生产系统中,必须高效使用计算以能够大规模处理请求,因此通常应在处理请求之前加载模型。

W
wizardforcel 已提交
212
## 将模型集成到我们的 API 服务器中
W
wizardforcel 已提交
213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283

在最后一部分中,我们将模型添加到 Flask API 服务器中。 由于我们的 API 服务器应该获取图像文件,因此我们将更新`predict`方法以从请求中读取文件:

```py
from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

```

`app.py`文件现在完成。 以下是完整版本; 将路径替换为保存文件的路径,它应运行:

```py
import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

if __name__ == '__main__':
    app.run()

```

让我们测试一下我们的网络服务器! 跑:

```py
$ FLASK_ENV=development FLASK_APP=app.py flask run

```

W
wizardforcel 已提交
284
我们可以使用[`requests`](https://pypi.org/project/requests/)库向我们的应用发送 POST 请求:
W
wizardforcel 已提交
285 286 287 288 289 290 291 292 293

```py
import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

```

W
wizardforcel 已提交
294
现在打印`resp.json()`将显示以下内容:
W
wizardforcel 已提交
295 296 297 298 299 300

```py
{"class_id": "n02124075", "class_name": "Egyptian_cat"}

```

W
wizardforcel 已提交
301
## 后续步骤
W
wizardforcel 已提交
302

W
wizardforcel 已提交
303
我们编写的服务器非常琐碎,可能无法完成生产应用所需的一切。 因此,您可以采取一些措施来改善它:
W
wizardforcel 已提交
304 305 306 307

*   端点`/predict`假定请求中始终会有一个图像文件。 这可能并不适用于所有请求。 我们的用户可能发送带有其他参数的图像,或者根本不发送任何图像。
*   用户也可以发送非图像类型的文件。 由于我们没有处理错误,因此这将破坏我们的服务器。 添加显式的错误处理路径将引发异常,这将使我们能够更好地处理错误的输入
*   即使模型可以识别大量类别的图像,也可能无法识别所有图像。 增强实现以处理模型无法识别图像中的任何情况的情况。
W
wizardforcel 已提交
308
*   我们在开发模式下运行 Flask 服务器,该服务器不适合在生产中进行部署。 您可以查看[本教程](https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/),以便在生产环境中部署 Flask 服务器。
W
wizardforcel 已提交
309
*   您还可以通过创建一个带有表单的页面来添加 UI,该表单可以拍摄图像并显示预测。 查看类似项目的[演示](https://pytorch-imagenet.herokuapp.com/)及其[源代码](https://github.com/avinassh/pytorch-flask-api-heroku)
W
wizardforcel 已提交
310
*   在本教程中,我们仅展示了如何构建可以一次返回单个图像预测的服务。 我们可以修改服务以能够一次返回多个图像的预测。 此外,[service-streamer](https://github.com/ShannonAI/service-streamer) 库自动将对服务的请求排队,并将请求采样到微型批量中,这些微型批量可输入模型中。 您可以查看[本教程](https://github.com/ShannonAI/service-streamer/wiki/Vision-Recognition-Service-with-Flask-and-service-streamer)
W
wizardforcel 已提交
311 312
*   最后,我们鼓励您在页面顶部查看链接到的其他 PyTorch 模型部署教程。

W
wizardforcel 已提交
313
**脚本的总运行时间**:(0 分钟 1.232 秒)
W
wizardforcel 已提交
314

W
wizardforcel 已提交
315
[下载 Python 源码:`flask_rest_api_tutorial.py`](../_downloads/146c514e84d7e33f2a302bcc3ae793cb/flask_rest_api_tutorial.py)
W
wizardforcel 已提交
316

W
wizardforcel 已提交
317
[下载 Jupyter 笔记本:`flask_rest_api_tutorial.ipynb`](../_downloads/6c042f3d39855d2a2de414758e5f9836/flask_rest_api_tutorial.ipynb)
W
wizardforcel 已提交
318

W
wizardforcel 已提交
319
[由 Sphinx 画廊](https://sphinx-gallery.readthedocs.io)生成的画廊