PyTorch 中的图像分类

相关空间: https://huggingface.co/spaces/abidlabs/pytorch-image-classifier,https : //huggingface.co/spaces/pytorch/ResNet,https : //huggingface.co/spaces/pytorch/ResNext,https ://huggingface.co/spaces/pytorch/SqueezeNet标签:VISION,RESNET,PYTORCH

介绍

图像分类是计算机视觉的核心任务。 构建更好的分类器来对图片中存在的对象进行分类是一个活跃的研究领域,因为它的应用范围从自动驾驶汽车到医学成像。

Image classification is a central task in computer vision. Building better classifiers to classify what object is present in a picture is an active area of research, as it has applications stretching from autonomous vehicles to medical imaging.

此类模型非常适合与 Gradio 的图像输入组件一起使用,因此在本教程中,我们将构建一个网络演示来使用 Gradio 对图像进行分类。 我们将能够用 Python 构建整个 Web 应用程序,它看起来像这样(尝试其中一个示例!):

Such models are perfect to use with Gradio's image input component, so in this tutorial we will build a web demo to classify images using Gradio. We will be able to build the whole web application in Python, and it will look like this (try one of the examples!):

让我们开始吧!

Let's get started!

先决条件

Prerequisites

确保你已经安装了gradio Python 包。 我们将使用预训练的图像分类模型,因此你还应该安装 torch

Make sure you have the gradio Python package already installed. We will be using a pretrained image classification model, so you should also have torch installed.

第一步——建立图像分类模型

首先,我们需要一个图像分类模型。 对于本教程,我们将使用预训练的 Resnet-18 模型,因为它可以很容易地从PyTorch Hub下载。 你可以使用不同的预训练模型或训练你自己的模型。

First, we will need an image classification model. For this tutorial, we will use a pretrained Resnet-18 model, as it is easily downloadable from PyTorch Hub. You can use a different pretrained model or train your own.

import torch

model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()

因为我们将使用模型进行推理,所以我们调用了 .eval() 方法。

Because we will be using the model for inference, we have called the .eval() method.

第 2 步 — 定义 predict 函数

接下来,我们需要定义一个函数来接收用户输入(在本例中为图像)并返回预测。 预测应作为字典返回,其键是类名,值是置信概率。 我们将从这个文本文件加载类名。

Next, we will need to define a function that takes in the user input, which in this case is an image, and returns the prediction. The prediction should be returned as a dictionary whose keys are class name and values are confidence probabilities. We will load the class names from this text file.

对于我们的预训练模型,它将如下所示:

In the case of our pretrained model, it will look like this:

import requests
from PIL import Image
from torchvision import transforms

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def predict(inp):
  inp = transforms.ToTensor()(inp).unsqueeze(0)
  with torch.no_grad():
    prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    confidences = {labels[i]: float(prediction[i]) for i in range(1000)}    
  return confidences

让我们分解一下。 该函数接受一个参数:

Let's break this down. The function takes one parameter:

  • inp :作为 PIL 图像的输入图像

    inp: the input image as a PIL image

然后,该函数将图像转换为 PIL 图像,然后最终转换为 PyTorch tensor ,将其传递给模型,并返回:

Then, the function converts the image to a PIL Image and then eventually a PyTorch tensor, passes it through the model, and returns:

  • confidences :预测,作为字典,其键是类标签,其值是置信概率

    confidences: the predictions, as a dictionary whose keys are class labels and whose values are confidence probabilities

第 3 步 — 创建渐变界面

现在我们已经设置了预测功能,我们可以围绕它创建一个渐变界面。

Now that we have our predictive function set up, we can create a Gradio Interface around it.

在这种情况下,输入组件是一个拖放图像组件。 要创建此输入,我们使用 Image(type="pil") 创建组件并处理预处理以将其转换为 PIL 图像。

In this case, the input component is a drag-and-drop image component. To create this input, we use Image(type="pil") which creates the component and handles the preprocessing to convert that to a PIL image.

输出组件将是一个 Label ,它以漂亮的形式显示顶部标签。 由于我们不想显示所有 1,000 个类标签,我们将通过将其构造为 Label(num_top_classes=3) 来自定义它以仅显示前 3 个图像。

The output component will be a Label, which displays the top labels in a nice form. Since we don't want to show all 1,000 class labels, we will customize it to show only the top 3 images by constructing it as Label(num_top_classes=3).

最后,我们将再添加一个参数 examples ,它允许我们使用一些预定义的示例预填充我们的界面。 Gradio 的代码如下所示:

Finally, we'll add one more parameter, the examples, which allows us to prepopulate our interfaces with a few predefined examples. The code for Gradio looks like this:

import gradio as gr

gr.Interface(fn=predict, 
             inputs=gr.Image(type="pil"),
             outputs=gr.Label(num_top_classes=3),
             examples=["lion.jpg", "cheetah.jpg"]).launch()

这会产生以下界面,你可以在浏览器中尝试(尝试上传你自己的示例!):

This produces the following interface, which you can try right here in your browser (try uploading your own examples!):


你完成了! 这就是为图像分类器构建 Web 演示所需的全部代码。 如果你想与其他人分享,请尝试在 launch() 界面时设置 share=True

And you're done! That's all the code you need to build a web demo for an image classifier. If you'd like to share with others, try setting share=True when you launch() the Interface!