使用Flask部署¶
创建日期:2020年5月4日 | 最后更新:2021年9月15日 | 最后验证:未验证
在本教程中,您将学习:
如何将训练好的PyTorch模型封装在Flask容器中,以通过Web API暴露它
如何将传入的网络请求转换为PyTorch张量以供您的模型使用
如何为HTTP响应打包模型的输出
需求¶
您需要一个安装了以下包(及其依赖项)的Python 3环境:
PyTorch 1.5
TorchVision 0.6.0
Flask 1.1
可选地,要获取一些支持文件,您需要git。
安装PyTorch和TorchVision的说明可在pytorch.org找到。安装Flask的说明可在the Flask site找到。
什么是Flask?¶
Flask 是一个用 Python 编写的轻量级 Web 服务器。它为您提供了一种便捷的方式,可以快速设置一个 Web API,用于从您训练好的 PyTorch 模型中进行预测,无论是直接使用,还是作为更大系统中的一个 Web 服务。
设置和支持文件¶
我们将创建一个网络服务,该服务接收图像并将其映射到ImageNet数据集的1000个类别之一。为此,您需要一个图像文件进行测试。您还可以选择获取一个文件,该文件将模型输出的类别索引映射为人类可读的类别名称。
选项1:快速获取两个文件¶
你可以通过检出TorchServe仓库并将它们复制到你的工作文件夹中快速获取这两个支持文件。(注意:本教程不依赖于TorchServe——这只是获取文件的快速方法。)在你的shell提示符下发出以下命令:
git clone https://github.com/pytorch/serve
cp serve/examples/image_classifier/kitten.jpg .
cp serve/examples/image_classifier/index_to_name.json .
你已经得到了它们!
选项2:使用您自己的图像¶
在下面的Flask服务中,index_to_name.json 文件是可选的。
你可以使用自己的图像测试你的服务 - 只需确保它是一个
3色JPEG。
构建您的Flask服务¶
完整的Flask服务的Python脚本在本食谱的末尾展示;你可以将其复制并粘贴到你自己的app.py文件中。下面我们将查看各个部分以明确它们的功能。
导入¶
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
按顺序:
我们将使用来自
torchvision.models的预训练DenseNet模型torchvision.transforms包含用于处理图像数据的工具Pillow (
PIL) 是我们最初用来加载图像文件的工具当然,我们需要来自
flask的类
预处理¶
def transform_image(infile):
input_transforms = [transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])]
my_transforms = transforms.Compose(input_transforms)
image = Image.open(infile)
timg = my_transforms(image)
timg.unsqueeze_(0)
return timg
网络请求给了我们一个图像文件,但我们的模型期望的是一个形状为 (N, 3, 224, 224) 的 PyTorch 张量,其中 N 是输入批次中的项目数量。(我们将只有一个批次大小为 1。)我们做的第一件事是组合一组 TorchVision 变换,这些变换会调整大小并裁剪图像,将其转换为张量,然后对张量中的值进行归一化。(有关此归一化的更多信息,请参阅 torchvision.models_ 的文档。)
之后,我们打开文件并应用转换。转换返回一个形状为 (3, 224, 224) 的张量 - 这是一个 224x224 图像的 3 个颜色通道。因为我们需要将这个单一图像变成一个批次,所以我们使用 unsqueeze_(0) 调用来通过添加一个新的第一维度来就地修改张量。张量包含相同的数据,但现在形状为 (1, 3, 224, 224)。
一般来说,即使你不处理图像数据,你也需要将来自HTTP请求的输入转换为PyTorch可以使用的张量。
推理¶
def get_prediction(input_tensor):
outputs = model.forward(input_tensor)
_, y_hat = outputs.max(1)
prediction = y_hat.item()
return prediction
推理本身是最简单的部分:当我们把输入张量传递给模型时,我们会得到一个表示模型估计图像属于特定类别的可能性的张量。max()调用找到具有最大可能性值的类别,并返回该值以及ImageNet类别索引。最后,我们通过item()调用从包含它的张量中提取该类别索引,并返回它。
后处理¶
def render_prediction(prediction_idx):
stridx = str(prediction_idx)
class_name = 'Unknown'
if img_class_map is not None:
if stridx in img_class_map is not None:
class_name = img_class_map[stridx][1]
return prediction_idx, class_name
render_prediction() 方法将预测的类别索引映射为人类可读的类别标签。通常,在从模型获得预测后,需要进行后处理以使预测结果适合人类阅读或供其他软件使用。
运行完整的Flask应用程序¶
将以下内容粘贴到名为 app.py 的文件中:
import io
import json
import os
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
model = models.densenet121(pretrained=True) # Trained on 1000 classes from ImageNet
model.eval() # Turns off autograd
img_class_map = None
mapping_file_path = 'index_to_name.json' # Human-readable names for Imagenet classes
if os.path.isfile(mapping_file_path):
with open (mapping_file_path) as f:
img_class_map = json.load(f)
# Transform input into the form our model expects
def transform_image(infile):
input_transforms = [transforms.Resize(255), # We use multiple TorchVision transforms to ready the image
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], # Standard normalization for ImageNet model input
[0.229, 0.224, 0.225])]
my_transforms = transforms.Compose(input_transforms)
image = Image.open(infile) # Open the image file
timg = my_transforms(image) # Transform PIL image to appropriately-shaped PyTorch tensor
timg.unsqueeze_(0) # PyTorch models expect batched input; create a batch of 1
return timg
# Get a prediction
def get_prediction(input_tensor):
outputs = model.forward(input_tensor) # Get likelihoods for all ImageNet classes
_, y_hat = outputs.max(1) # Extract the most likely class
prediction = y_hat.item() # Extract the int value from the PyTorch tensor
return prediction
# Make the prediction human-readable
def render_prediction(prediction_idx):
stridx = str(prediction_idx)
class_name = 'Unknown'
if img_class_map is not None:
if stridx in img_class_map is not None:
class_name = img_class_map[stridx][1]
return prediction_idx, class_name
@app.route('/', methods=['GET'])
def root():
return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
if file is not None:
input_tensor = transform_image(file)
prediction_idx = get_prediction(input_tensor)
class_id, class_name = render_prediction(prediction_idx)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
要从您的 shell 提示符启动服务器,请发出以下命令:
FLASK_APP=app.py flask run
默认情况下,您的Flask服务器正在监听端口5000。一旦服务器运行,打开另一个终端窗口,并测试您的新推理服务器:
curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "file=@kitten.jpg"
如果一切设置正确,您应该会收到类似于以下的响应:
{"class_id":285,"class_name":"Egyptian_cat"}
重要资源¶
pytorch.org 用于安装说明,以及更多文档和教程