4. 计算技巧#
4.1. 即时编译#
即时编译(JIT)是一种在JAX库中用于优化和加速数值计算执行的技术。因此,它可以使skscope中的求解器更高效地执行。我们可以通过在solve方法中设置jit=True来轻松使用JIT加速求解器的执行:
from skscope import ScopeSolver
def objective_fn(params):
value = 0.0
# do somethings
return value
solver = ScopeSolver(
dimensionality=10, ## there are p parameters
sparsity=3, ## the candidate support sizes
)
solver.solve(objective_fn, jit=True)
在JIT模式开启或关闭的情况下,运行时间的比较显示,JIT通常可以加速计算,速度提升范围在2到30倍之间。以下是不同求解器在不同问题上非JIT模式与JIT模式的运行时间比率:
线性回归 |
逻辑回归 |
多任务学习 |
非线性特征选择 |
趋势过滤 |
伊辛模型 |
|
|---|---|---|---|---|---|---|
FobaSolver |
11.93 |
19.16 |
7.02 |
4.32 |
2.97 |
14.76 |
GraspSolver |
5.76 |
31.63 |
6.73 |
1.07 |
1.81 |
10.34 |
HTPSolver |
5.34 |
13.55 |
11.16 |
1.21 |
0.89 |
13.26 |
IHTSolver |
1.06 |
3.28 |
1.84 |
0.53 |
0.25 |
4.89 |
OMPSolver |
11.33 |
20.88 |
9.82 |
2.83 |
0.9 |
16.16 |
ScopeSolver |
5.24 |
17.26 |
2.06 |
2.01 |
3.21 |
11.21 |
> 请注意,JIT 对目标函数的编程有额外的要求。更多详细信息可以在 JAX 文档 中找到。
4.2. 支持GPU设备#
skscope 不排除使用GPU设备进行计算。
事实上,当用户正确安装与物理设备匹配的JAX时,他们可以在没有任何额外命令的情况下使用GPU进行计算。
> JAX 在 GPU 或 TPU 上透明运行(如果没有 GPU 或 TPU,则回退到 CPU)。
为了确保通用性,skscope 依赖于仅CPU版本的JAX。
因此,对于希望使用GPU的用户,他们只需按照说明
正确安装与物理设备匹配的JAX版本。例如:
pip install -U "jax[cuda12]"
4.3. 支持稀疏矩阵#
感谢jax,skscope支持将输入矩阵作为稀疏矩阵。尽管使用稀疏矩阵会增加自动微分所需的时间,但它可以显著减少内存使用。下面,我们提供了一个线性回归的例子来展示这一功能。首先,我们导入必要的库并过滤掉警告以获得更清晰的输出。
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse
from skscope import ScopeSolver
import scipy.sparse as sp
import warnings
warnings.filterwarnings('ignore')
接下来,我们使用scipy.sparse生成一个随机稀疏矩阵,使用JAX将其转换为密集矩阵,然后将其转换为BCOO格式的稀疏矩阵。我们还基于预定义的向量创建了一个目标向量,并添加了一些噪声。
n, p = 150, 30
np.random.seed(0)
random_sparse_matrix = sp.random(n, p, density=0.1, format='coo')
dense_matrix = jnp.array(random_sparse_matrix.toarray())
X = sparse.BCOO.fromdense(dense_matrix)
beta = np.zeros(p)
beta[:3] = [1, 2, 3]
y = X @ beta + np.random.normal(0, 0.1, n)
我们定义了一个简单的普通最小二乘(OLS)损失函数,由ScopeSolver进行最小化。
def ols_loss(params):
loss = jnp.mean((y - X @ params) ** 2)
return loss
最后,我们初始化ScopeSolver,指定要选择的特征数量,并求解最优参数。
solver = ScopeSolver(p, sparsity=3)
params_skscope = solver.solve(ols_loss, jit=True)
然后,我们可以得到 params_skscope 作为子集选择的结果。