Keras 和 Tensorflow
内容
Keras 和 Tensorflow¶
包 SciKeras 将 Scikit-learn API 引入 Keras。这使得 Dask-ML 能够与 Keras 模型无缝使用。
安装¶
按照 Tensorflow 安装指南 和 SciKeras 安装指南 ,需要安装以下包:
$ pip install tensorflow>=2.3.0
$ pip install scikeras>=0.1.8
这些是 Dask-ML 使用 Tensorflow/Keras 所需的最低版本。
用法¶
首先,让我们从定义一个普通函数开始来创建我们的模型。这是创建 Keras Sequential 模型 的常规方法。
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
def build_model(lr=0.01, momentum=0.9):
layers = [Dense(512, input_shape=(784,), activation="relu"),
Dense(10, input_shape=(512,), activation="softmax")]
model = Sequential(layers)
opt = tf.keras.optimizers.SGD(
learning_rate=lr, momentum=momentum, nesterov=True,
)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
return model
现在,我们可以使用 SciKeras 来创建一个与 Scikit-learn 兼容的模型:
from scikeras.wrappers import KerasClassifier
niceties = dict(verbose=False)
model = KerasClassifier(build_fn=build_model, lr=0.1, momentum=0.9, **niceties)
该模型将与所有 Dask-ML 兼容:它可以使用 NumPy 数组作为输入,并遵循 Scikit-learn API。例如,可以使用 Dask-ML 执行以下操作:
使用 Keras 与 Dask-ML 的模型选择,包括
HyperbandSearchCV。使用 Keras 与 Dask-ML 的
Incremental。
如果我们想要调整 lr 和 momentum,SciKeras 要求我们在初始化时传递 lr 和 momentum:
model = KerasClassifier(build_fn=build_model, lr=None, momentum=None, **niceties)
SciKeras 支持更多模型创建方法,包括一些与 Tensorflow 向后兼容的方法。详情请参阅其文档。
示例:超参数优化¶
如果我们愿意,我们可以使用上面的模型与 HyperbandSearchCV。让我们在MNIST数据集上调整这个模型:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
from typing import Tuple
def get_mnist() -> Tuple[np.ndarray, np.ndarray]:
(X_train, y_train), _ = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 784)
X_train = X_train.astype("float32")
X_train /= 255
return X_train, y_train
让我们执行调整SGD实现的基本任务:
from scipy.stats import loguniform, uniform
params = {"lr": loguniform(1e-3, 1e-1), "momentum": uniform(0, 1)}
X, y = get_mnist()
现在,可以运行搜索了:
from dask.distributed import Client
client = Client()
from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, max_iter=27)
search.fit(X, y)