注意
已弃用的@st.cache装饰器的文档可以在使用st.cache优化性能中找到。
Caching overview
Streamlit 在每次用户交互或代码更改时从上到下运行您的脚本。这种执行模型使开发变得非常容易。但它带来了两个主要挑战:
- 长时间运行的函数会反复运行,这会减慢您的应用程序。
- 对象会一次又一次地重新创建,这使得在重新运行或会话之间持久化它们变得困难。
但别担心!Streamlit 允许你通过其内置的缓存机制来解决这两个问题。缓存会存储慢速函数调用的结果,因此它们只需要运行一次。这使你的应用程序更快,并有助于在重新运行时持久化对象。缓存的值对你的应用程序的所有用户都可用。如果你需要保存只能在会话中访问的结果,请改用Session State。
Minimal example
要在Streamlit中缓存一个函数,你必须使用两个装饰器之一(st.cache_data 或 st.cache_resource)来装饰它:
@st.cache_data
def long_running_function(param1, param2):
return …
在这个例子中,使用@st.cache_data装饰long_running_function告诉Streamlit,每当调用该函数时,它会检查两件事:
- 输入参数的值(在这种情况下,
param1和param2)。 - 函数内部的代码。
如果这是Streamlit第一次看到这些参数值和函数代码,它会运行该函数并将返回值存储在缓存中。下次使用相同参数和代码调用该函数时(例如,当用户与应用交互时),Streamlit将完全跳过执行函数并返回缓存的值。在开发过程中,缓存会随着函数代码的变化自动更新,确保缓存中反映最新的更改。
如前所述,有两个缓存装饰器:
st.cache_data是推荐用于缓存返回数据的计算的方式:从CSV加载DataFrame,转换NumPy数组,查询API,或任何其他返回可序列化数据对象(str, int, float, DataFrame, array, list, …)的函数。它在每次函数调用时创建数据的新副本,使其免受突变和竞争条件的影响。st.cache_data的行为在大多数情况下是您想要的——所以如果您不确定,请从st.cache_data开始,看看它是否有效!st.cache_resource是推荐用于缓存全局资源(如ML模型或数据库连接)的方式——这些是不可序列化的对象,您不希望多次加载。使用它,您可以在应用程序的所有重新运行和会话中共享这些资源,而无需复制或重复。请注意,对缓存返回值的任何修改都会直接修改缓存中的对象(更多详细信息见下文)。

Streamlit的两个缓存装饰器及其使用案例。
Basic usage
st.cache_data
st.cache_data 是您用于所有返回数据的函数的首选命令——无论是DataFrames、NumPy数组、str、int、float还是其他可序列化类型。它几乎适用于所有用例!在每个用户会话中,使用@st.cache_data装饰的函数会返回缓存返回值的副本(如果该值已被缓存)。
用法
让我们看一个使用st.cache_data的例子。假设你的应用程序从互联网加载Uber拼车数据集——一个50 MB的CSV文件——到一个DataFrame中:
def load_data(url):
df = pd.read_csv(url) # 👈 Download the data
return df
df = load_data("https://github.com/plotly/datasets/raw/master/uber-rides-data1.csv")
st.dataframe(df)
st.button("Rerun")
运行load_data函数需要2到30秒,具体取决于您的网络连接。(提示:如果您使用的是慢速连接,请使用这个5 MB的数据集代替)。如果没有缓存,每次加载应用程序或用户交互时都会重新下载。通过点击我们添加的按钮亲自尝试一下!这不是一个很好的体验… 😕
现在让我们在 load_data 上添加 @st.cache_data 装饰器:
@st.cache_data # 👈 Add the caching decorator
def load_data(url):
df = pd.read_csv(url)
return df
df = load_data("https://github.com/plotly/datasets/raw/master/uber-rides-data1.csv")
st.dataframe(df)
st.button("Rerun")
再次运行应用程序。您会注意到,缓慢的下载只会在第一次运行时发生。每次后续的重新运行应该几乎是即时的!💨
行为
这是如何工作的?让我们一步步了解st.cache_data的行为:
- 在第一次运行时,Streamlit 发现它从未使用指定的参数值(CSV文件的URL)调用过
load_data函数。因此,它运行该函数并下载数据。 - 现在我们的缓存机制开始生效:返回的DataFrame通过pickle进行序列化(转换为字节)并存储在缓存中(与
url参数的值一起)。 - 在下一次运行时,Streamlit 会检查缓存中是否有特定
url的load_data条目。有一个!因此它会检索缓存的对象,将其反序列化为 DataFrame,并返回它,而不是重新运行函数并再次下载数据。
序列化和反序列化缓存对象的过程会创建我们原始DataFrame的副本。虽然这种复制行为可能看起来不必要,但当我们缓存数据对象时,这正是我们想要的,因为它有效地防止了变异和并发问题。阅读下面的“变异和并发问题”部分以更详细地理解这一点。
警告
st.cache_data 隐式使用了 pickle 模块,该模块已知是不安全的。您缓存的函数返回的任何内容都会被序列化并存储,然后在检索时反序列化。请确保您的缓存函数返回可信的值,因为有可能构造恶意的序列化数据,在反序列化期间执行任意代码。切勿以不安全模式加载可能来自不受信任来源的数据,或者可能被篡改的数据。只加载您信任的数据。
示例
数据框转换
在上面的例子中,我们已经展示了如何缓存加载一个DataFrame。缓存DataFrame的转换操作,如df.filter、df.apply或df.sort_values,也可能非常有用。特别是对于大型DataFrame,这些操作可能会很慢。
@st.cache_data
def transform(df):
df = df.filter(items=['one', 'three'])
df = df.apply(np.sum, axis=0)
return df
数组计算
同样地,缓存NumPy数组上的计算也是有意义的:
@st.cache_data
def add(arr1, arr2):
return arr1 + arr2
数据库查询
在使用数据库时,您通常通过SQL查询将数据加载到应用程序中。重复运行这些查询可能会很慢,花费金钱,并降低数据库的性能。我们强烈建议在应用程序中缓存任何数据库查询。另请参阅我们的指南,了解如何将Streamlit连接到不同的数据库,以获取深入的示例。
connection = database.connect()
@st.cache_data
def query():
return pd.read_sql_query("SELECT * from table", connection)
提示
你应该设置一个ttl(生存时间)以从数据库中获取新结果。如果你设置st.cache_data(ttl=3600),Streamlit 将在1小时(3600秒)后使任何缓存的值失效,并再次运行缓存的函数。详情请参见控制缓存大小和持续时间。
API调用
同样地,缓存API调用也是有意义的。这样做还可以避免速率限制。
@st.cache_data
def api_call():
response = requests.get('https://jsonplaceholder.typicode.com/posts/1')
return response.json()
运行机器学习模型(推理)
运行复杂的机器学习模型可能会消耗大量的时间和内存。为了避免重复运行相同的计算,请使用缓存。
@st.cache_data
def run_model(inputs):
return model(inputs)
st.cache_resource
st.cache_resource 是用于缓存“资源”的正确命令,这些资源应该在全球范围内对所有用户、会话和重新运行都可用。它的使用场景比 st.cache_data 更有限,特别是用于缓存数据库连接和机器学习模型。在每个用户会话中,使用 @st.cache_resource 装饰的函数会返回缓存中的返回值实例(如果该值已经被缓存)。因此,由 st.cache_resource 缓存的对象表现得像单例,并且可以发生变异。
用法
作为st.cache_resource的一个示例,让我们来看一个典型的机器学习应用程序。首先,我们需要加载一个机器学习模型。我们使用Hugging Face的transformers库来完成这一步骤:
from transformers import pipeline
model = pipeline("sentiment-analysis") # 👈 Load the model
如果我们直接将这段代码放入Streamlit应用程序中,应用程序将在每次重新运行或用户交互时加载模型。重复加载模型会带来两个问题:
- 加载模型需要时间并会减慢应用程序的速度。
- 每个会话从头加载模型,这会占用大量内存。
相反,更有意义的是加载模型一次,并在所有用户和会话中使用相同的对象。这正是st.cache_resource的用例!让我们将其添加到我们的应用程序中,并处理用户输入的一些文本:
from transformers import pipeline
@st.cache_resource # 👈 Add the caching decorator
def load_model():
return pipeline("sentiment-analysis")
model = load_model()
query = st.text_input("Your query", value="I love Streamlit! 🎈")
if query:
result = model(query)[0] # 👈 Classify the query text
st.write(result)
如果你运行这个应用程序,你会看到应用程序只调用一次load_model——就在应用程序启动时。后续运行将重用存储在缓存中的相同模型,节省时间和内存!
行为
使用 st.cache_resource 与使用 st.cache_data 非常相似。但在行为上有一些重要的区别:
-
st.cache_resource不会创建缓存返回值的副本,而是将对象本身存储在缓存中。对函数返回值的所有修改都会直接影响缓存中的对象,因此您必须确保来自多个会话的修改不会导致问题。简而言之,返回值必须是线程安全的。priority_high 警告
在非线程安全的对象上使用
st.cache_resource可能会导致崩溃或数据损坏。了解更多信息,请参阅 Mutation and concurrency issues。 -
不创建副本意味着只有一个缓存的返回对象的全局实例,这样可以节省内存,例如在使用大型机器学习模型时。用计算机科学的术语来说,我们创建了一个单例。
-
函数的返回值不需要是可序列化的。这种行为对于本质上不可序列化的类型非常有利,例如数据库连接、文件句柄或线程。使用
st.cache_data缓存这些对象是不可能的。
示例
数据库连接
st.cache_resource 对于连接数据库非常有用。通常,您会创建一个连接对象,希望在每个查询中全局重用。每次运行时创建新的连接对象效率低下,并可能导致连接错误。这正是 st.cache_resource 可以做到的,例如,对于 Postgres 数据库:
@st.cache_resource
def init_connection():
host = "hh-pgsql-public.ebi.ac.uk"
database = "pfmegrnargs"
user = "reader"
password = "NWDMCE5xdipIjRrp"
return psycopg2.connect(host=host, database=database, user=user, password=password)
conn = init_connection()
当然,您也可以对其他数据库执行相同的操作。请查看我们的指南,了解如何将Streamlit连接到数据库,以获取深入的示例。
加载机器学习模型
您的应用程序应始终缓存ML模型,以便它们不会为每个新会话再次加载到内存中。请参阅上面的示例,了解这与🤗 Hugging Face模型的配合方式。您可以为PyTorch、TensorFlow等做同样的事情。以下是一个PyTorch的示例:
@st.cache_resource
def load_model():
model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()
return model
model = load_model()
Deciding which caching decorator to use
上面的部分展示了每个缓存装饰器的许多常见示例。但是有一些边缘情况,决定使用哪个缓存装饰器并不那么简单。最终,这一切都归结为“数据”和“资源”之间的区别:
- 数据是可序列化的对象(可以通过pickle转换为字节的对象),您可以轻松地将其保存到磁盘。想象一下您通常会存储在数据库或文件系统中的所有类型——基本类型如str、int和float,还有数组、DataFrames、图像或这些类型的组合(列表、元组、字典等)。
- 资源是不可序列化的对象,通常不会保存到磁盘或数据库中。它们通常是更复杂、非永久性的对象,如数据库连接、ML模型、文件句柄、线程等。
从上面列出的类型来看,很明显Python中的大多数对象都是“数据”。这也是为什么st.cache_data是几乎所有用例的正确命令。st.cache_resource是一个更特殊的命令,你应该只在特定情况下使用。
或者如果你懒得不想想太多,可以在下面的表格中查找你的用例或返回类型 😉:
| 使用案例 | 典型返回类型 | 缓存装饰器 |
|---|---|---|
| 使用 pd.read_csv 读取 CSV 文件 | pandas.DataFrame | st.cache_data |
| 读取文本文件 | str, list of str | st.cache_data |
| 转换pandas数据框 | pandas.DataFrame, pandas.Series | st.cache_data |
| 使用numpy数组进行计算 | numpy.ndarray | st.cache_data |
| 使用基本类型进行简单计算 | str, int, float, … | st.cache_data |
| 查询数据库 | pandas.DataFrame | st.cache_data |
| 查询API | pandas.DataFrame, str, dict | st.cache_data |
| 运行一个机器学习模型(推理) | pandas.DataFrame, str, int, dict, list | st.cache_data |
| 创建或处理图像 | PIL.Image.Image, numpy.ndarray | st.cache_data |
| 创建图表 | matplotlib.figure.Figure, plotly.graph_objects.Figure, altair.Chart | st.cache_data(但有些库需要st.cache_resource,因为图表对象不可序列化——确保在创建后不要修改图表!) |
| 加载机器学习模型 | transformers.Pipeline, torch.nn.Module, tensorflow.keras.Model | st.cache_resource |
| 初始化数据库连接 | pyodbc.Connection, sqlalchemy.engine.base.Engine, psycopg2.connection, mysql.connector.MySQLConnection, sqlite3.Connection | st.cache_resource |
| 打开持久文件句柄 | _io.TextIOWrapper | st.cache_resource |
| 打开持久线程 | threading.thread | st.cache_resource |
Advanced usage
Controlling cache size and duration
如果你的应用程序运行时间很长并且不断缓存函数,你可能会遇到两个问题:
- 应用程序因为缓存太大而耗尽内存。
- 缓存中的对象变得过时,例如,因为你缓存了数据库中的旧数据。
你可以使用ttl和max_entries参数来解决这些问题,这些参数适用于两种缓存装饰器。
ttl(生存时间)参数
ttl 设置缓存函数的生存时间。如果时间到了并且你再次调用该函数,应用程序将丢弃任何旧的缓存值,并重新运行该函数。新计算的值将存储在缓存中。这种行为对于防止数据过时(问题2)和缓存变得过大(问题1)非常有用。特别是在从数据库或API中提取数据时,你应该始终设置一个ttl,这样你就不会使用旧数据。以下是一个示例:
@st.cache_data(ttl=3600) # 👈 Cache data for 1 hour (=3600 seconds)
def get_api_data():
data = api.get(...)
return data
提示
你也可以使用timedelta来设置ttl值,例如,ttl=datetime.timedelta(hours=1)。
max_entries 参数
max_entries 设置缓存中的最大条目数。缓存条目数的上限对于限制内存(问题1)非常有用,尤其是在缓存大对象时。当向已满的缓存添加新条目时,最旧的条目将被移除。以下是一个示例:
@st.cache_data(max_entries=1000) # 👈 Maximum 1000 entries in the cache
def get_large_array(seed):
np.random.seed(seed)
arr = np.random.rand(100000)
return arr
Customizing the spinner
默认情况下,当缓存函数运行时,Streamlit 会在应用程序中显示一个小的加载旋转器。您可以使用 show_spinner 参数轻松修改它,该参数适用于两个缓存装饰器:
@st.cache_data(show_spinner=False) # 👈 Disable the spinner
def get_api_data():
data = api.get(...)
return data
@st.cache_data(show_spinner="Fetching data from API...") # 👈 Use custom text for spinner
def get_api_data():
data = api.get(...)
return data
Excluding input parameters
在缓存函数中,所有输入参数必须是可哈希的。让我们快速解释一下原因及其含义。当函数被调用时,Streamlit会查看其参数值以确定是否之前被缓存过。因此,它需要一种可靠的方法来比较函数调用之间的参数值。对于字符串或整数来说很简单——但对于任意对象来说就很复杂了!Streamlit使用哈希来解决这个问题。它将参数转换为一个稳定的键并存储该键。在下次函数调用时,它会再次哈希参数并将其与存储的哈希键进行比较。
不幸的是,并非所有参数都是可哈希的!例如,您可能会传递一个不可哈希的数据库连接或机器学习模型到您的缓存函数中。在这种情况下,您可以从缓存中排除输入参数。只需在参数名称前加上下划线(例如,_param1),它将不会被用于缓存。即使它发生变化,如果所有其他参数匹配,Streamlit 仍将返回缓存的结果。
这是一个例子:
@st.cache_data
def fetch_data(_db_connection, num_rows): # 👈 Don't hash _db_connection
data = _db_connection.fetch(num_rows)
return data
connection = init_connection()
fetch_data(connection, 10)
但是,如果你想缓存一个带有不可哈希参数的函数怎么办?例如,你可能想缓存一个以ML模型为输入并返回该模型的层名称的函数。由于模型是唯一的输入参数,你不能将其从缓存中排除。在这种情况下,你可以使用hash_funcs参数为模型指定一个自定义的哈希函数。
The hash_funcs parameter
如上所述,Streamlit的缓存装饰器对输入参数和缓存函数的签名进行哈希处理,以确定该函数是否之前运行过并存储了返回值(“缓存命中”)或需要运行(“缓存未命中”)。对于Streamlit的哈希实现无法哈希的输入参数,可以通过在其名称前添加下划线来忽略它们。但在两种罕见情况下,这是不可取的。即,当你想要对Streamlit无法哈希的参数进行哈希处理时:
- 当 Streamlit 的哈希机制无法对参数进行哈希时,导致抛出
UnhashableParamError。 - 当您想要覆盖Streamlit的默认哈希机制以用于参数时。
让我们依次讨论这些案例,并举例说明。
示例 1: 哈希自定义类
Streamlit 不知道如何对自定义类进行哈希处理。如果你将一个自定义类传递给一个缓存函数,Streamlit 会抛出一个 UnhashableParamError。例如,我们定义一个自定义类 MyCustomClass,它接受一个初始的整数分数。我们还定义一个缓存函数 multiply_score,它将分数乘以一个乘数:
import streamlit as st
class MyCustomClass:
def __init__(self, initial_score: int):
self.my_score = initial_score
@st.cache_data
def multiply_score(obj: MyCustomClass, multiplier: int) -> int:
return obj.my_score * multiplier
initial_score = st.number_input("Enter initial score", value=15)
score = MyCustomClass(initial_score)
multiplier = 2
st.write(multiply_score(score, multiplier))
如果你运行这个应用程序,你会看到Streamlit会抛出一个UnhashableParamError,因为它不知道如何哈希MyCustomClass:
UnhashableParamError: Cannot hash argument 'obj' (of type __main__.MyCustomClass) in 'multiply_score'.
为了解决这个问题,我们可以使用hash_funcs参数来告诉Streamlit如何对MyCustomClass进行哈希。我们通过向hash_funcs传递一个字典来实现这一点,该字典将参数名称映射到一个哈希函数。哈希函数的选择由开发者决定。在这种情况下,让我们定义一个自定义哈希函数hash_func,它接受自定义类作为输入并返回分数。我们希望分数是对象的唯一标识符,因此我们可以使用它来确定性哈希对象:
import streamlit as st
class MyCustomClass:
def __init__(self, initial_score: int):
self.my_score = initial_score
def hash_func(obj: MyCustomClass) -> int:
return obj.my_score # or any other value that uniquely identifies the object
@st.cache_data(hash_funcs={MyCustomClass: hash_func})
def multiply_score(obj: MyCustomClass, multiplier: int) -> int:
return obj.my_score * multiplier
initial_score = st.number_input("Enter initial score", value=15)
score = MyCustomClass(initial_score)
multiplier = 2
st.write(multiply_score(score, multiplier))
现在如果你运行应用程序,你会看到Streamlit不再引发UnhashableParamError,并且应用程序按预期运行。
现在让我们考虑multiply_score是MyCustomClass的一个属性的情况,并且我们想要对整个对象进行哈希处理:
import streamlit as st
class MyCustomClass:
def __init__(self, initial_score: int):
self.my_score = initial_score
@st.cache_data
def multiply_score(self, multiplier: int) -> int:
return self.my_score * multiplier
initial_score = st.number_input("Enter initial score", value=15)
score = MyCustomClass(initial_score)
multiplier = 2
st.write(score.multiply_score(multiplier))
如果你运行这个应用程序,你会看到Streamlit会引发一个UnhashableParamError,因为它无法在'multiply_score'中对参数'self'(类型为__main__.MyCustomClass)进行哈希处理。这里的一个简单修复方法可能是使用Python的hash()函数来哈希对象:
import streamlit as st
class MyCustomClass:
def __init__(self, initial_score: int):
self.my_score = initial_score
@st.cache_data(hash_funcs={"__main__.MyCustomClass": lambda x: hash(x.my_score)})
def multiply_score(self, multiplier: int) -> int:
return self.my_score * multiplier
initial_score = st.number_input("Enter initial score", value=15)
score = MyCustomClass(initial_score)
multiplier = 2
st.write(score.multiply_score(multiplier))
上面,哈希函数被定义为 lambda x: hash(x.my_score)。这基于 MyCustomClass 实例的 my_score 属性创建了一个哈希。只要 my_score 保持不变,哈希值也保持不变。因此,可以从缓存中检索 multiply_score 的结果,而无需重新计算。
作为一个精明的Pythonista,你可能曾经想使用Python的id()函数来哈希对象,如下所示:
import streamlit as st
class MyCustomClass:
def __init__(self, initial_score: int):
self.my_score = initial_score
@st.cache_data(hash_funcs={"__main__.MyCustomClass": id})
def multiply_score(self, multiplier: int) -> int:
return self.my_score * multiplier
initial_score = st.number_input("Enter initial score", value=15)
score = MyCustomClass(initial_score)
multiplier = 2
st.write(score.multiply_score(multiplier))
如果你运行这个应用程序,你会注意到即使my_score没有改变,Streamlit每次都会重新计算multiply_score!感到困惑吗?在Python中,id()返回一个对象的标识符,这个标识符在对象的生命周期内是唯一且不变的。这意味着即使MyCustomClass的两个实例中的my_score值相同,id()也会为这两个实例返回不同的值,从而导致不同的哈希值。因此,Streamlit认为这两个不同的实例需要单独的缓存值,因此即使my_score没有改变,它每次都会重新计算multiply_score。
这就是为什么我们不鼓励将其用作哈希函数,而是鼓励使用返回确定性、真实哈希值的函数。也就是说,如果你知道自己在做什么,你可以使用id()作为哈希函数。只是要意识到后果。例如,当你将@st.cache_resource函数的结果作为输入参数传递给另一个缓存函数时,id通常是正确的哈希函数。有一整类对象类型在其他情况下是不可哈希的。
示例 2: 哈希化一个 Pydantic 模型
让我们考虑另一个例子,我们想要哈希一个Pydantic模型:
import streamlit as st
from pydantic import BaseModel
class Person(BaseModel):
name: str
@st.cache_data
def identity(person: Person):
return person
person = identity(Person(name="Lee"))
st.write(f"The person is {person.name}")
在上面,我们使用Pydantic的BaseModel定义了一个自定义类Person,它有一个单一属性name。我们还定义了一个identity函数,它接受一个Person的实例作为参数并返回它而不做任何修改。这个函数旨在缓存结果,因此,如果使用相同的Person实例多次调用它,它不会重新计算,而是返回缓存的实例。
然而,如果你运行这个应用程序,你会遇到一个UnhashableParamError: Cannot hash argument 'person' (of type __main__.Person) in 'identity'.错误。这是因为Streamlit不知道如何哈希Person类。为了解决这个问题,我们可以使用hash_funcs参数来告诉Streamlit如何哈希Person。
在下面的版本中,我们定义了一个自定义哈希函数 hash_func,它接受 Person 实例作为输入并返回 name 属性。我们希望 name 成为对象的唯一标识符,因此我们可以使用它来确定性哈希对象:
import streamlit as st
from pydantic import BaseModel
class Person(BaseModel):
name: str
@st.cache_data(hash_funcs={Person: lambda p: p.name})
def identity(person: Person):
return person
person = identity(Person(name="Lee"))
st.write(f"The person is {person.name}")
示例 3: 哈希一个机器学习模型
在某些情况下,您可能希望将您最喜欢的机器学习模型传递给缓存函数。例如,假设您希望根据用户在应用程序中选择的模型,将TensorFlow模型传递给缓存函数。您可能会尝试这样做:
import streamlit as st
import tensorflow as tf
@st.cache_resource
def load_base_model(option):
if option == 1:
return tf.keras.applications.ResNet50(include_top=False, weights="imagenet")
else:
return tf.keras.applications.MobileNetV2(include_top=False, weights="imagenet")
@st.cache_resource
def load_layers(base_model):
return [layer.name for layer in base_model.layers]
option = st.radio("Model 1 or 2", [1, 2])
base_model = load_base_model(option)
layers = load_layers(base_model)
st.write(layers)
在上述应用程序中,用户可以选择两种模型之一。根据选择,应用程序加载相应的模型并将其传递给load_layers。然后,此函数返回模型中层的名称。如果您运行该应用程序,您会看到Streamlit引发了一个UnhashableParamError,因为它无法对'load_layers'中的参数'base_model'(类型为keras.engine.functional.Functional)进行哈希处理。
如果你通过在base_model名称前加上下划线来禁用哈希,你会观察到无论选择哪个基础模型,显示的层都是相同的。这个微妙的错误是由于当基础模型改变时,load_layers函数没有重新运行。这是因为Streamlit没有对base_model参数进行哈希处理,所以它不知道当基础模型改变时需要重新运行该函数。
为了解决这个问题,我们可以使用hash_funcs参数来告诉Streamlit如何对base_model参数进行哈希处理。在下面的版本中,我们定义了一个自定义的哈希函数hash_func:Functional: lambda x: x.name。我们选择这个哈希函数是基于我们对Functional对象或模型的name属性的了解,该属性唯一标识了它。只要name属性保持不变,哈希值也保持不变。因此,可以从缓存中检索load_layers的结果,而无需重新计算。
import streamlit as st
import tensorflow as tf
from keras.engine.functional import Functional
@st.cache_resource
def load_base_model(option):
if option == 1:
return tf.keras.applications.ResNet50(include_top=False, weights="imagenet")
else:
return tf.keras.applications.MobileNetV2(include_top=False, weights="imagenet")
@st.cache_resource(hash_funcs={Functional: lambda x: x.name})
def load_layers(base_model):
return [layer.name for layer in base_model.layers]
option = st.radio("Model 1 or 2", [1, 2])
base_model = load_base_model(option)
layers = load_layers(base_model)
st.write(layers)
在上述情况下,我们也可以使用hash_funcs={Functional: id}作为哈希函数。这是因为当你将@st.cache_resource函数的结果作为输入参数传递给另一个缓存函数时,id通常是正确的哈希函数。
示例 4:覆盖 Streamlit 的默认哈希机制
让我们考虑另一个例子,我们想要覆盖Streamlit的默认哈希机制,用于一个pytz本地化的datetime对象:
from datetime import datetime
import pytz
import streamlit as st
tz = pytz.timezone("Europe/Berlin")
@st.cache_data
def load_data(dt):
return dt
now = datetime.now()
st.text(load_data(dt=now))
now_tz = tz.localize(datetime.now())
st.text(load_data(dt=now_tz))
可能会让人感到惊讶的是,尽管now和now_tz属于相同的类型,Streamlit 却不知道如何对now_tz进行哈希处理,并引发了一个UnhashableParamError。在这种情况下,我们可以通过向hash_funcs关键字参数传递一个自定义的哈希函数来覆盖 Streamlit 对datetime对象的默认哈希机制:
from datetime import datetime
import pytz
import streamlit as st
tz = pytz.timezone("Europe/Berlin")
@st.cache_data(hash_funcs={datetime: lambda x: x.strftime("%a %d %b %Y, %I:%M%p")})
def load_data(dt):
return dt
now = datetime.now()
st.text(load_data(dt=now))
now_tz = tz.localize(datetime.now())
st.text(load_data(dt=now_tz))
现在让我们考虑一个情况,我们想要覆盖Streamlit对NumPy数组的默认哈希机制。虽然Streamlit原生支持对Pandas和NumPy对象的哈希处理,但在某些情况下,你可能想要覆盖Streamlit对这些对象的默认哈希机制。
例如,假设我们创建了一个带有缓存装饰的show_data函数,该函数接受一个NumPy数组并返回它而不进行修改。在下面的应用程序中,data = df["str"].unique()(这是一个NumPy数组)被传递给show_data函数。
import time
import numpy as np
import pandas as pd
import streamlit as st
@st.cache_data
def get_data():
df = pd.DataFrame({"num": [112, 112, 2, 3], "str": ["be", "a", "be", "c"]})
return df
@st.cache_data
def show_data(data):
time.sleep(2) # This makes the function take 2s to run
return data
df = get_data()
data = df["str"].unique()
st.dataframe(show_data(data))
st.button("Re-run")
由于data始终相同,我们期望show_data函数返回缓存的值。然而,如果你运行应用程序并点击Re-run按钮,你会注意到show_data函数每次都会重新运行。我们可以假设这种行为是Streamlit对NumPy数组默认哈希机制的结果。
为了解决这个问题,让我们定义一个自定义哈希函数 hash_func,它接受一个 NumPy 数组作为输入并返回数组的字符串表示:
import time
import numpy as np
import pandas as pd
import streamlit as st
@st.cache_data
def get_data():
df = pd.DataFrame({"num": [112, 112, 2, 3], "str": ["be", "a", "be", "c"]})
return df
@st.cache_data(hash_funcs={np.ndarray: str})
def show_data(data):
time.sleep(2) # This makes the function take 2s to run
return data
df = get_data()
data = df["str"].unique()
st.dataframe(show_data(data))
st.button("Re-run")
现在如果你运行应用程序,并点击Re-run按钮,你会注意到show_data函数不再每次都被重新运行。这里需要注意的是,我们选择的哈希函数非常朴素,不一定是最好的选择。例如,如果NumPy数组很大,将其转换为字符串表示可能会很昂贵。在这种情况下,作为开发者,你需要根据你的使用场景来定义什么是好的哈希函数。
静态元素
自版本1.16.0以来,缓存的函数可以包含Streamlit命令!例如,你可以这样做:
@st.cache_data
def get_api_data():
data = api.get(...)
st.success("Fetched data from API!") # 👈 Show a success message
return data
众所周知,Streamlit 只会在函数未被缓存之前运行此函数。在第一次运行时,st.success 消息将出现在应用程序中。但在后续运行中会发生什么?它仍然会出现!Streamlit 意识到缓存函数内部有一个 st. 命令,在第一次运行时保存它,并在后续运行中重放它。重放静态元素适用于两种缓存装饰器。
你也可以使用这个功能来缓存你的UI的整个部分:
@st.cache_data
def show_data():
st.header("Data analysis")
data = api.get(...)
st.success("Fetched data from API!")
st.write("Here is a plot of the data:")
st.line_chart(data)
st.write("And here is the raw data:")
st.dataframe(data)
输入小部件
你也可以在缓存函数中使用交互式输入小部件,比如st.slider或st.text_input。目前,小部件重放是一个实验性功能。要启用它,你需要设置experimental_allow_widgets参数:
@st.cache_data(experimental_allow_widgets=True) # 👈 Set the parameter
def get_data():
num_rows = st.slider("Number of rows to get") # 👈 Add a slider
data = api.get(..., num_rows)
return data
Streamlit 将滑块视为缓存函数的额外输入参数。如果您更改滑块位置,Streamlit 将检查是否已经为此滑块值缓存了函数。如果是,它将返回缓存的值。如果不是,它将使用新的滑块值重新运行函数。
在缓存函数中使用小部件非常强大,因为它允许您缓存应用程序的整个部分。但这可能很危险!由于Streamlit将小部件值视为额外的输入参数,它很容易导致内存使用过多。想象一下,您的缓存函数有五个滑块并返回一个100 MB的DataFrame。然后,我们将为这五个滑块值的每个排列添加100 MB到缓存中——即使滑块不影响返回的数据!这些添加可能会使您的缓存迅速爆炸。如果您在缓存函数中使用小部件,请注意此限制。我们建议仅在UI的隔离部分使用此功能,其中小部件直接影响缓存的返回值。
警告
对缓存函数中小部件的支持是实验性的。我们可能会随时更改或删除它,恕不另行通知。请谨慎使用!
注意
目前有两个小部件在缓存函数中不受支持:st.file_uploader 和 st.camera_input。我们未来可能会支持它们。如果你需要它们,请随时在GitHub上提出问题!
Dealing with large data
正如我们所解释的,你应该使用st.cache_data来缓存数据对象。但对于非常大的数据,例如超过1亿行的DataFrames或数组,这可能会很慢。这是因为st.cache_data的复制行为:在第一次运行时,它将返回值序列化为字节,并在后续运行时反序列化。这两个操作都需要时间。
如果你正在处理非常大的数据,使用st.cache_resource可能更有意义。它不会通过序列化/反序列化创建返回值的副本,几乎是即时的。但要注意:对函数返回值的任何修改(例如从DataFrame中删除一列或在数组中设置值)都会直接操作缓存中的对象。你必须确保这不会破坏你的数据或导致崩溃。请参阅下面的突变和并发问题部分。
在对具有四列的pandas DataFrames进行st.cache_data基准测试时,我们发现当行数超过1亿时,速度会变慢。下表显示了不同行数(均为四列)下两个缓存装饰器的运行时间:
| 1000万行 | 5000万行 | 1亿行 | 2亿行 | ||
|---|---|---|---|---|---|
| st.cache_data | 首次运行* | 0.4 秒 | 3 秒 | 14 秒 | 28 秒 |
| 后续运行 | 0.2 秒 | 1 秒 | 2 秒 | 7 秒 | |
| st.cache_resource | 首次运行* | 0.01 秒 | 0.1 秒 | 0.2 秒 | 1 秒 |
| 后续运行 | 0 秒 | 0 秒 | 0 秒 | 0 秒 |
| *第一次运行时,表格仅显示使用缓存装饰器的开销时间。它不包括缓存函数本身的运行时间。 |
Mutation and concurrency issues
在以上部分中,我们讨论了很多关于缓存函数返回对象变异时的问题。这个话题很复杂!但它是理解st.cache_data和st.cache_resource之间行为差异的核心。因此,让我们更深入地探讨一下。
首先,我们应该明确定义我们所说的突变和并发是什么意思:
-
通过mutations,我们指的是在调用缓存函数后对其返回值所做的任何更改。例如,像这样:
@st.cache_data def create_list(): l = [1, 2, 3] l = create_list() # 👈 调用函数 l[0] = 2 # 👈 更改其返回值 -
通过并发,我们指的是多个会话可以同时引起这些变化。Streamlit 是一个需要处理许多用户和会话连接到应用程序的 Web 框架。如果两个人同时查看一个应用程序,他们都会导致 Python 脚本重新运行,这可能会同时操作缓存的返回对象——即并发地操作。
修改缓存的返回对象可能是危险的。它可能导致应用程序中的异常,甚至损坏您的数据(这可能比应用程序崩溃更糟糕!)。下面,我们将首先解释st.cache_data的复制行为,并展示它如何避免修改问题。然后,我们将展示并发修改如何导致数据损坏以及如何防止它。
复制行为
st.cache_data 每次调用函数时都会创建缓存返回值的副本。这避免了大多数突变和并发问题。为了详细理解它,让我们回到上面关于st.cache_data部分的Uber拼车示例。我们对其进行了两项修改:
- 我们正在使用
st.cache_resource而不是st.cache_data。st.cache_resource不会创建缓存对象的副本,因此我们可以看到在没有复制行为的情况下会发生什么。 - 加载数据后,我们通过删除列
"Lat"来操作返回的DataFrame(原地操作!)。
这是代码:
@st.cache_resource # 👈 Turn off copying behavior
def load_data(url):
df = pd.read_csv(url)
return df
df = load_data("https://raw.githubusercontent.com/plotly/datasets/master/uber-rides-data1.csv")
st.dataframe(df)
df.drop(columns=['Lat'], inplace=True) # 👈 Mutate the dataframe inplace
st.button("Rerun")
让我们运行它,看看会发生什么!第一次运行应该没问题。但在第二次运行时,你会看到一个异常:KeyError: "['Lat'] not found in axis"。为什么会发生这种情况?让我们一步一步来看:
- 在第一次运行时,Streamlit 运行
load_data并将生成的 DataFrame 存储在缓存中。由于我们使用的是st.cache_resource,它不会创建副本,而是存储原始的 DataFrame。 - 然后我们从DataFrame中删除列
"Lat"。请注意,这是从存储在缓存中的原始DataFrame中删除列。我们正在操作它! - 在第二次运行时,Streamlit 从缓存中返回了完全相同的被操作过的 DataFrame。它不再有
"Lat"列了!因此,我们对df.drop的调用导致了一个异常。Pandas 无法删除一个不存在的列。
st.cache_data 的复制行为防止了这种变异错误。变异只能影响特定的副本,而不能影响缓存中的基础对象。下一次重新运行将获得其自己的、未变异的 DataFrame 副本。你可以自己尝试,只需将上面的 st.cache_resource 替换为 st.cache_data,你会发现一切正常。
由于这种复制行为,st.cache_data 是推荐用于缓存数据转换和计算的方式——任何返回可序列化对象的内容。
并发问题
现在让我们看看当多个用户同时修改缓存中的对象时会发生什么。假设你有一个返回列表的函数。再次,我们使用st.cache_resource来缓存它,这样我们就不会创建副本:
@st.cache_resource
def create_list():
l = [1, 2, 3]
return l
l = create_list()
first_list_value = l[0]
l[0] = first_list_value + 1
st.write("l[0] is:", l[0])
假设用户 A 运行该应用程序。他们将看到以下输出:
l[0] is: 2
假设另一个用户 B 在之后立即访问该应用程序。与用户 A 不同,他们将看到以下输出:
l[0] is: 3
现在,用户A在用户B之后立即重新运行应用程序。他们将看到以下输出:
l[0] is: 4
这里发生了什么?为什么所有输出都不同?
- 当用户A访问应用程序时,
create_list()被调用,列表[1, 2, 3]被存储在缓存中。然后这个列表被返回给用户A。列表的第一个值1被赋值给first_list_value,并且l[0]被更改为2。 - 当用户B访问应用程序时,
create_list()从缓存中返回已变异的列表:[2, 2, 3]。列表的第一个值2被分配给first_list_value,并且l[0]被更改为3。 - 当用户A重新运行应用程序时,
create_list()再次返回已变异的列表:[3, 2, 3]。列表的第一个值3被赋值给first_list_value,,并且l[0]被更改为 4。
如果你仔细想想,这是有道理的。用户A和用户B使用相同的列表对象(存储在缓存中的那个)。由于列表对象是可变的,用户A对列表对象的更改也会反映在用户B的应用程序中。
这就是为什么你必须小心处理使用st.cache_resource缓存的对象,特别是当多个用户同时访问应用程序时。如果我们使用了st.cache_data而不是st.cache_resource,应用程序会为每个用户复制列表对象,上面的例子就会按预期工作——用户A和B都会看到:
l[0] is: 2
注意
这个玩具示例可能看起来无害。但数据损坏可能极其危险!想象一下,如果我们在这里处理一家大型银行的财务记录。你肯定不希望因为有人使用了错误的缓存装饰器而醒来时账户里的钱变少了😉
Migrating from st.cache
我们在Streamlit 1.18.0中引入了上述缓存命令。在此之前,我们有一个通用的命令st.cache。使用它经常令人困惑,导致奇怪的异常,并且速度较慢。这就是为什么我们在1.18.0版本中用新命令替换了st.cache(更多信息请参阅此博客文章)。新命令提供了一种更直观和高效的方式来缓存您的数据和资源,并旨在在所有新开发中取代st.cache。
如果你的应用仍然在使用 st.cache,不要绝望!以下是一些迁移的注意事项:
- 如果你的应用使用了
st.cache,Streamlit 将会显示一个弃用警告。 - 我们不会很快移除
st.cache,所以你不需要担心你两年前的应用会崩溃。但我们鼓励你尝试新的命令——它们会少很多麻烦! - 在大多数情况下,将代码切换到新命令应该很容易。要决定是使用
st.cache_data还是st.cache_resource,请阅读Deciding which caching decorator to use。Streamlit 还会识别常见的使用场景,并在弃用警告中直接显示提示。 - 大多数来自
st.cache的参数也存在于新命令中,但有一些例外:allow_output_mutation不再存在。你可以安全地删除它。只需确保你为你的用例使用正确的缓存命令。suppress_st_warning不再存在。你可以安全地删除它。缓存函数现在可以包含Streamlit命令,并将重放它们。如果你想在缓存函数中使用小部件,请设置experimental_allow_widgets=True。参见输入小部件以获取示例。
如果您在迁移过程中有任何问题或疑问,请在论坛上联系我们,我们将很乐意帮助您。🎈
还有问题吗?
我们的 论坛 充满了有用的信息和Streamlit专家。