相关空间: https://huggingface.co/spaces/nateraw/quickdraw 标签:SKETCHPAD、LABELS、LIVE
算法能猜出你在画什么? 几年前,谷歌发布了Quick Draw数据集,其中包含人类绘制的各种物体的图画。 研究人员使用这个数据集来训练模型来猜测 Pictionary 风格的图画。
How well can an algorithm guess what you're drawing? A few years ago, Google released the Quick Draw dataset, which contains drawings made by humans of a variety of every objects. Researchers have used this dataset to train models to guess Pictionary-style drawings.
此类模型非常适合与 Gradio 的画板输入一起使用,因此在本教程中,我们将使用 Gradio 构建一个 Pictionary Web 应用程序。 我们将能够在 Python 中构建整个 Web 应用程序,并且看起来像这样(尝试画一些东西!):
Such models are perfect to use with Gradio's sketchpad input, so in this tutorial we will build a Pictionary web application using Gradio. We will be able to build the whole web application in Python, and will look like this (try drawing something!):
让我们开始吧! 本指南介绍了如何构建一个图画应用程序(分步):
Let's get started! This guide covers how to build a pictionary app (step-by-step):
确保你已经安装了gradio Python 包。 要使用预训练的画板模型,还要安装 torch 。
Make sure you have the gradio Python package already installed. To use the pretrained sketchpad model, also install torch.
首先,你需要一个草图识别模型。 由于许多研究人员已经在 Quick Draw 数据集上训练了自己的模型,因此我们将在本教程中使用预训练模型。 我们的模型是由 Nate Raw 训练的一个 1.5 MB 的轻型模型,你可以在此处下载。
First, you will need a sketch recognition model. Since many researchers have already trained their own models on the Quick Draw dataset, we will use a pretrained model in this tutorial. Our model is a light 1.5 MB model trained by Nate Raw, that you can download here.
如果你有兴趣,这里是用于训练模型的代码。 我们将简单地在 PyTorch 中加载预训练模型,如下所示:
If you are interested, here is the code that was used to train the model. We will simply load the pretrained model in PyTorch, as follows:
import torch
from torch import nn
model = nn.Sequential(
nn.Conv2d(1, 32, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding='same'),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1152, 256),
nn.ReLU(),
nn.Linear(256, len(LABELS)),
)
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
model.eval()
predict 函数接下来,你需要定义一个函数来接收用户输入(在本例中为素描图像)并返回预测结果。 预测应作为字典返回,其键是类名,值是置信概率。 我们将从这个文本文件加载类名。
Next, you will need to define a function that takes in the user input, which in this case is a sketched image, and returns the prediction. The prediction should be returned as a dictionary whose keys are class name and values are confidence probabilities. We will load the class names from this text file.
对于我们的预训练模型,它将如下所示:
In the case of our pretrained model, it will look like this:
from pathlib import Path
LABELS = Path('class_names.txt').read_text().splitlines()
def predict(img):
x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
with torch.no_grad():
out = model(x)
probabilities = torch.nn.functional.softmax(out[0], dim=0)
values, indices = torch.topk(probabilities, 5)
confidences = {LABELS[i]: v.item() for i, v in zip(indices, values)}
return confidences
让我们分解一下。 该函数采用一个参数:
Let's break this down. The function takes one parameters:
img :作为 numpy 数组的输入图像
img: the input image as a numpy array
然后,该函数将图像转换为 PyTorch tensor ,将其传递给模型,并返回:
Then, the function converts the image to a PyTorch tensor, passes it through the model, and returns:
confidences :前五个预测,作为一个字典,其键是类标签,其值是置信概率
confidences: the top five predictions, as a dictionary whose keys are class labels and whose values are confidence probabilities
现在我们已经设置了预测功能,我们可以围绕它创建一个渐变界面。
Now that we have our predictive function set up, we can create a Gradio Interface around it.
在这种情况下,输入组件是一个画板。 要创建画板输入,我们可以使用方便的字符串快捷方式 "sketchpad" ,它会创建一个画布供用户在其上绘制并处理预处理以将其转换为 numpy 数组。
In this case, the input component is a sketchpad. To create a sketchpad input, we can use the convenient string shortcut, "sketchpad" which creates a canvas for a user to draw on and handles the preprocessing to convert that to a numpy array.
输出组件将是一个 "label" ,它以漂亮的形式显示顶部标签。
The output component will be a "label", which displays the top labels in a nice form.
最后,我们将再添加一个参数,设置 live=True ,它允许我们的界面实时运行,每次用户在画板上绘图时调整其预测。 Gradio 的代码如下所示:
Finally, we'll add one more parameter, setting live=True, which allows our interface to run in real time, adjusting its predictions every time a user draws on the sketchpad. The code for Gradio looks like this:
import gradio as gr
gr.Interface(fn=predict,
inputs="sketchpad",
outputs="label",
live=True).launch()
这将产生以下界面,你可以在浏览器中尝试(尝试绘制一些东西,例如“蛇”或“笔记本电脑”):
This produces the following interface, which you can try right here in your browser (try drawing something, like a "snake" or a "laptop"):
你完成了! 这就是构建 Pictionary 风格的猜谜应用程序所需的全部代码。 玩得开心并尝试找到一些边缘案例🧐
And you're done! That's all the code you need to build a Pictionary-style guessing app. Have fun and try to find some edge cases 🧐