保存模型#

save_model(sktime_model, path, conda_env=None, code_paths=None, mlflow_model=None, signature=None, input_example=None, pip_requirements=None, extra_pip_requirements=None, serialization_format='pickle')[源代码][源代码]#

将 sktime 模型保存到本地文件系统中的路径。

参数:
sktime_model

拟合的 sktime 模型对象。

路径str

模型要保存的本地路径。

conda_envUnion[dict, str], 可选 (默认=None)

一个Conda环境的字典表示,或者是一个Conda环境yaml文件的路径。

代码路径类似数组,可选(默认=None)

Python 文件依赖项(或包含文件依赖项的目录)的本地文件系统路径列表。这些文件在加载模型时会被 前置 到系统路径中。

mlflow_model: mlflow.models.Model, 可选 (默认=None)

要添加 python_function 风格的 mlflow.models.Model 配置。

签名mlflow.models.signature.ModelSignature, 可选 (默认=None)

Model Signature mlflow.models.ModelSignature 描述了模型的输入和输出 Schema。模型签名可以从具有有效模型输入(例如,省略目标列的训练数据集)和有效模型输出(例如,在训练数据集上生成的模型预测)的数据集中 推断,例如:

from mlflow.models.signature import infer_signature
train = df.drop_column("target_label")
predictions = ... # compute model predictions
signature = infer_signature(train, predictions)

警告

如果在使用 sktime 模型进行概率预测(predict_intervalpredict_quantiles)时,由于使用这些方法时返回的预测对象的 Pandas MultiIndex 列类型,签名将无法正确推断。不过,如果使用模型的 pyfunc 风格,infer_schema 将正常工作。

输入示例Union[pandas.core.frame.DataFrame, numpy.ndarray, dict, list, csr_matrix, csc_matrix], 可选
(默认=无)

输入示例提供了一个或多个有效的模型输入实例。该示例可以用作向模型提供什么数据的提示。给定的示例将被转换为 Pandas DataFrame,然后使用 Pandas 的面向分割的格式序列化为 json。字节被 base64 编码。

pip_requirementsUnion[Iterable, str], 可选 (默认=None)

可以是pip需求字符串的可迭代对象(例如[“sktime”, “-r requirements.txt”, “-c constraints.txt”]),或者是本地文件系统上pip需求文件的字符串路径(例如”requirements.txt”)

extra_pip_requirementsUnion[Iterable, str], 可选 (默认=None)

要么是一个包含pip需求字符串的可迭代对象(例如 [“pandas”, “-r requirements.txt”, “-c constraints.txt”]),要么是本地文件系统上pip需求文件的字符串路径(例如 “requirements.txt”)

序列化格式str, 可选 (默认值=”pickle”)

模型的序列化格式。这应该是 “pickle” 或 “cloudpickle” 格式之一。

参考文献

[1]

https://www.mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.Model.save

示例

>>> from sktime.datasets import load_airline  
>>> from sktime.forecasting.arima import ARIMA  
>>> from sktime.utils import mlflow_sktime  
>>> y = load_airline()  
>>> forecaster = ARIMA(  
...     order=(1, 1, 0),
...     seasonal_order=(0, 1, 0, 12),
...     suppress_warnings=True)
>>> forecaster.fit(y)  
ARIMA(...)
>>> model_path = "model"  
>>> mlflow_sktime.save_model(  
...     sktime_model=forecaster,
...     path=model_path)  
>>> loaded_model = mlflow_sktime.load_model(model_uri=model_path)  
>>> loaded_model.predict(fh=[1, 2, 3])