图像分类是计算机视觉中的核心任务。构建更好的分类器来识别图片中的物体是一个活跃的研究领域,因为它从自动驾驶汽车到医学成像都有广泛的应用。
这样的模型非常适合与Gradio的图像输入组件一起使用,因此在本教程中,我们将构建一个使用Gradio进行图像分类的Web演示。我们将能够在Python中构建整个Web应用程序,它将看起来像页面底部的演示。
让我们开始吧!
确保你已经安装了gradio
Python包。我们将使用一个预训练的图片分类模型,因此你也应该安装torch
。
首先,我们需要一个图像分类模型。在本教程中,我们将使用预训练的Resnet-18模型,因为它可以从PyTorch Hub轻松下载。您可以使用不同的预训练模型或训练自己的模型。
import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
因为我们将使用模型进行推理,所以我们调用了.eval()
方法。
predict
函数接下来,我们需要定义一个函数,该函数接收用户输入,在这种情况下是一张图片,并返回预测结果。预测结果应作为字典返回,其键为类名,值为置信概率。我们将从这个文本文件中加载类名。
在我们的预训练模型的情况下,它看起来像这样:
import requests
from PIL import Image
from torchvision import transforms
# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
def predict(inp):
inp = transforms.ToTensor()(inp).unsqueeze(0)
with torch.no_grad():
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
return confidences
让我们分解一下。该函数接受一个参数:
inp
: 输入图像作为 PIL
图像然后,函数将图像转换为PIL图像,最终转换为PyTorch tensor
,通过模型传递,并返回:
confidences
: 预测结果,作为一个字典,其键是类别标签,值是置信概率现在我们已经设置好了预测函数,我们可以围绕它创建一个Gradio界面。
在这种情况下,输入组件是一个拖放图像组件。要创建此输入,我们使用Image(type="pil")
,它创建组件并处理预处理以将其转换为PIL
图像。
输出组件将是一个Label
,它以美观的形式显示顶部标签。由于我们不想显示所有1,000个类别标签,我们将通过将其构造为Label(num_top_classes=3)
来自定义它,以仅显示前3个图像。
最后,我们将再添加一个参数,examples
,它允许我们用一些预定义的示例预填充我们的接口。Gradio 的代码如下所示:
import gradio as gr
gr.Interface(fn=predict,
inputs=gr.Image(type="pil"),
outputs=gr.Label(num_top_classes=3),
examples=["lion.jpg", "cheetah.jpg"]).launch()
这将生成以下界面,您可以直接在浏览器中尝试(尝试上传您自己的示例!):
你已经完成了!这就是你构建图像分类器网络演示所需的所有代码。如果你想与他人分享,尝试在launch()
接口时设置share=True
!