1. 其他教程
  2. 使用标记

使用标记

介绍

当你演示一个机器学习模型时,你可能希望从尝试该模型的用户那里收集数据,特别是模型表现不如预期的数据点。捕获这些“困难”的数据点非常有价值,因为它允许你改进你的机器学习模型,使其更加可靠和稳健。

Gradio 通过在每个 Interface 中包含一个 Flag 按钮,简化了这些数据的收集。这使得用户或测试人员可以轻松地将数据发送回运行演示的机器。在本指南中,我们将详细讨论如何使用标记功能,无论是在 gradio.Interface 还是 gradio.Blocks 中。

gradio.Interface中的Flag按钮

使用Gradio的Interface进行标记特别容易。默认情况下,在输出组件下方有一个标记为Flag的按钮。当用户测试您的模型时,如果看到有趣的输出,他们可以点击标记按钮将输入和输出数据发送回运行演示的机器。样本会保存到一个CSV日志文件中(默认情况下)。如果演示涉及图像、音频、视频或其他类型的文件,这些文件会单独保存在一个并行目录中,并且这些文件的路径会保存在CSV文件中。

gradio.Interface中有四个参数控制标记的工作方式。我们将更详细地介绍它们。

  • flagging_mode: 此参数可以设置为"manual"(默认)、"auto""never"
    • manual: 用户将看到一个标记按钮,只有在点击按钮时才会标记样本。
    • auto: 用户不会看到标记按钮,但每个样本都会自动标记。
    • never: 用户不会看到标记按钮,也不会标记任何样本。
  • flagging_options: 此参数可以是 None(默认)或字符串列表。
    • 如果为 None,则用户只需点击标记按钮,不会显示其他选项。
    • 如果提供了字符串列表,则用户会看到多个按钮,每个按钮对应提供的字符串。例如,如果此参数的值为 ["Incorrect", "Ambiguous"],则会出现标记为标记为不正确标记为模糊的按钮。这仅在 flagging_mode"manual" 时适用。
    • 然后,所选选项将与输入和输出一起记录。
  • flagging_dir: 此参数接受一个字符串。
    • 它表示用于存储标记数据的目录的名称。
  • flagging_callback: 此参数接受FlaggingCallback类的一个子类的实例
    • 使用此参数可以编写自定义代码,当点击标记按钮时运行
    • 默认情况下,此参数设置为gr.JSONLogger的实例

标记的数据会发生什么?

在由flagging_dir参数提供的目录中,一个JSON文件将记录被标记的数据。

这是一个例子:下面的代码创建了嵌入在下面的计算器界面:

import gradio as gr


def calculator(num1, operation, num2):
    if operation == "add":
        return num1 + num2
    elif operation == "subtract":
        return num1 - num2
    elif operation == "multiply":
        return num1 * num2
    elif operation == "divide":
        return num1 / num2


iface = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    allow_flagging="manual"
)

iface.launch()

当你点击上面的标记按钮时,启动接口的目录将包含一个新的标记子文件夹,里面有一个csv文件。这个csv文件包含所有被标记的数据。

+-- flagged/
|   +-- logs.csv

已标记/日志.csv

num1,operation,num2,Output,timestamp
5,add,7,12,2022-01-31 11:40:51.093412
6,subtract,1.5,4.5,2022-01-31 03:25:32.023542

如果接口涉及文件数据,例如图像和音频组件,将创建文件夹来存储那些标记的数据。例如,一个从image输入到image输出的接口将创建以下结构。

+-- flagged/
|   +-- logs.csv
|   +-- image/
|   |   +-- 0.png
|   |   +-- 1.png
|   +-- Output/
|   |   +-- 0.png
|   |   +-- 1.png

flagged/logs.csv

im,Output timestamp
im/0.png,Output/0.png,2022-02-04 19:49:58.026963
im/1.png,Output/1.png,2022-02-02 10:40:51.093412

如果您希望用户提供标记的原因,可以将字符串列表传递给Interface的flagging_options参数。用户在标记时必须选择其中一个选项,该选项将作为附加列保存到CSV中。

如果我们回到计算器示例,以下代码将创建嵌入在下面的界面。

iface = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="manual",
    flagging_options=["wrong sign", "off by one", "other"]
)

iface.launch()

当用户点击标记按钮时,csv文件现在将包含一个指示所选选项的列。

已标记/日志.csv

num1,operation,num2,Output,flag,timestamp
5,add,7,-12,wrong sign,2022-02-04 11:40:51.093412
6,subtract,1.5,3.5,off by one,2022-02-04 11:42:32.062512

使用块进行标记

如果你正在使用gradio.Blocks呢?一方面,使用Blocks你会有更多的灵活性——你可以编写任何你想在按钮点击时运行的Python代码,并使用Blocks中的内置事件来分配它。

同时,您可能希望使用现有的FlaggingCallback以避免编写额外的代码。 这需要两个步骤:

  1. 你必须在代码中某个地方运行你的回调函数的.setup(),在你第一次标记数据之前
  2. 当点击标记按钮时,您将触发回调的.flag()方法,确保正确收集参数并禁用典型的预处理。

这是一个带有图像棕褐色滤镜的示例块演示,允许您使用默认的CSVLogger标记数据:

import numpy as np
import gradio as gr

def sepia(input_img, strength):
    sepia_filter = strength * np.array(
        [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
    ) + (1-strength) * np.identity(3)
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

callback = gr.CSVLogger()

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            img_input = gr.Image()
            strength = gr.Slider(0, 1, 0.5)
        img_output = gr.Image()
    with gr.Row():
        btn = gr.Button("Flag")

    # This needs to be called at some point prior to the first call to callback.flag()
    callback.setup([img_input, strength, img_output], "flagged_data_points")

    img_input.change(sepia, [img_input, strength], img_output)
    strength.change(sepia, [img_input, strength], img_output)

    # We can choose which components to flag -- in this case, we'll flag all of them
    btn.click(lambda *args: callback.flag(list(args)), [img_input, strength, img_output], None, preprocess=False)

demo.launch()

隐私

重要提示:请确保您的用户了解他们提交的数据何时被保存,以及您计划如何处理这些数据。当您使用flagging_mode=auto(当通过演示提交的所有数据都被标记时)时,这一点尤为重要。

就这样!祝您构建愉快 :)