当你演示一个机器学习模型时,你可能希望从尝试该模型的用户那里收集数据,特别是模型表现不如预期的数据点。捕获这些“困难”的数据点非常有价值,因为它允许你改进你的机器学习模型,使其更加可靠和稳健。
Gradio 通过在每个 Interface 中包含一个 Flag 按钮,简化了这些数据的收集。这使得用户或测试人员可以轻松地将数据发送回运行演示的机器。在本指南中,我们将详细讨论如何使用标记功能,无论是在 gradio.Interface 还是 gradio.Blocks 中。
gradio.Interface中的Flag按钮使用Gradio的Interface进行标记特别容易。默认情况下,在输出组件下方有一个标记为Flag的按钮。当用户测试您的模型时,如果看到有趣的输出,他们可以点击标记按钮将输入和输出数据发送回运行演示的机器。样本会保存到一个CSV日志文件中(默认情况下)。如果演示涉及图像、音频、视频或其他类型的文件,这些文件会单独保存在一个并行目录中,并且这些文件的路径会保存在CSV文件中。
在gradio.Interface中有四个参数控制标记的工作方式。我们将更详细地介绍它们。
flagging_mode: 此参数可以设置为"manual"(默认)、"auto"或"never"。manual: 用户将看到一个标记按钮,只有在点击按钮时才会标记样本。auto: 用户不会看到标记按钮,但每个样本都会自动标记。never: 用户不会看到标记按钮,也不会标记任何样本。flagging_options: 此参数可以是 None(默认)或字符串列表。None,则用户只需点击标记按钮,不会显示其他选项。["Incorrect", "Ambiguous"],则会出现标记为标记为不正确和标记为模糊的按钮。这仅在 flagging_mode 为 "manual" 时适用。flagging_dir: 此参数接受一个字符串。flagging_callback: 此参数接受FlaggingCallback类的一个子类的实例gr.JSONLogger的实例在由flagging_dir参数提供的目录中,一个JSON文件将记录被标记的数据。
这是一个例子:下面的代码创建了嵌入在下面的计算器界面:
import gradio as gr
def calculator(num1, operation, num2):
if operation == "add":
return num1 + num2
elif operation == "subtract":
return num1 - num2
elif operation == "multiply":
return num1 * num2
elif operation == "divide":
return num1 / num2
iface = gr.Interface(
calculator,
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
allow_flagging="manual"
)
iface.launch()当你点击上面的标记按钮时,启动接口的目录将包含一个新的标记子文件夹,里面有一个csv文件。这个csv文件包含所有被标记的数据。
+-- flagged/
| +-- logs.csv已标记/日志.csv
num1,operation,num2,Output,timestamp
5,add,7,12,2022-01-31 11:40:51.093412
6,subtract,1.5,4.5,2022-01-31 03:25:32.023542如果接口涉及文件数据,例如图像和音频组件,将创建文件夹来存储那些标记的数据。例如,一个从image输入到image输出的接口将创建以下结构。
+-- flagged/
| +-- logs.csv
| +-- image/
| | +-- 0.png
| | +-- 1.png
| +-- Output/
| | +-- 0.png
| | +-- 1.pngflagged/logs.csv
im,Output timestamp
im/0.png,Output/0.png,2022-02-04 19:49:58.026963
im/1.png,Output/1.png,2022-02-02 10:40:51.093412如果您希望用户提供标记的原因,可以将字符串列表传递给Interface的flagging_options参数。用户在标记时必须选择其中一个选项,该选项将作为附加列保存到CSV中。
如果我们回到计算器示例,以下代码将创建嵌入在下面的界面。
iface = gr.Interface(
calculator,
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
flagging_mode="manual",
flagging_options=["wrong sign", "off by one", "other"]
)
iface.launch()当用户点击标记按钮时,csv文件现在将包含一个指示所选选项的列。
已标记/日志.csv
num1,operation,num2,Output,flag,timestamp
5,add,7,-12,wrong sign,2022-02-04 11:40:51.093412
6,subtract,1.5,3.5,off by one,2022-02-04 11:42:32.062512如果你正在使用gradio.Blocks呢?一方面,使用Blocks你会有更多的灵活性——你可以编写任何你想在按钮点击时运行的Python代码,并使用Blocks中的内置事件来分配它。
同时,您可能希望使用现有的FlaggingCallback以避免编写额外的代码。
这需要两个步骤:
.setup(),在你第一次标记数据之前.flag()方法,确保正确收集参数并禁用典型的预处理。这是一个带有图像棕褐色滤镜的示例块演示,允许您使用默认的CSVLogger标记数据:
import numpy as np
import gradio as gr
def sepia(input_img, strength):
sepia_filter = strength * np.array(
[[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
) + (1-strength) * np.identity(3)
sepia_img = input_img.dot(sepia_filter.T)
sepia_img /= sepia_img.max()
return sepia_img
callback = gr.CSVLogger()
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
img_input = gr.Image()
strength = gr.Slider(0, 1, 0.5)
img_output = gr.Image()
with gr.Row():
btn = gr.Button("Flag")
# This needs to be called at some point prior to the first call to callback.flag()
callback.setup([img_input, strength, img_output], "flagged_data_points")
img_input.change(sepia, [img_input, strength], img_output)
strength.change(sepia, [img_input, strength], img_output)
# We can choose which components to flag -- in this case, we'll flag all of them
btn.click(lambda *args: callback.flag(list(args)), [img_input, strength, img_output], None, preprocess=False)
demo.launch()
重要提示:请确保您的用户了解他们提交的数据何时被保存,以及您计划如何处理这些数据。当您使用flagging_mode=auto(当通过演示提交的所有数据都被标记时)时,这一点尤为重要。