4. 计算技巧#

4.1. 即时编译#

即时编译(JIT)是一种在JAX库中用于优化和加速数值计算执行的技术。因此,它可以使skscope中的求解器更高效地执行。我们可以通过在solve方法中设置jit=True来轻松使用JIT加速求解器的执行:

from skscope import ScopeSolver

def objective_fn(params):
    value = 0.0
    # do somethings
    return value
solver = ScopeSolver(
    dimensionality=10,  ## there are p parameters
    sparsity=3,         ## the candidate support sizes
)
solver.solve(objective_fn, jit=True)

在JIT模式开启或关闭的情况下,运行时间的比较显示,JIT通常可以加速计算,速度提升范围在2到30倍之间。以下是不同求解器在不同问题上非JIT模式与JIT模式的运行时间比率:

线性回归

逻辑回归

多任务学习

非线性特征选择

趋势过滤

伊辛模型

FobaSolver

11.93

19.16

7.02

4.32

2.97

14.76

GraspSolver

5.76

31.63

6.73

1.07

1.81

10.34

HTPSolver

5.34

13.55

11.16

1.21

0.89

13.26

IHTSolver

1.06

3.28

1.84

0.53

0.25

4.89

OMPSolver

11.33

20.88

9.82

2.83

0.9

16.16

ScopeSolver

5.24

17.26

2.06

2.01

3.21

11.21

> 请注意,JIT 对目标函数的编程有额外的要求。更多详细信息可以在 JAX 文档 中找到。

4.2. 支持GPU设备#

skscope 不排除使用GPU设备进行计算。 事实上,当用户正确安装与物理设备匹配的JAX时,他们可以在没有任何额外命令的情况下使用GPU进行计算。

> JAX 在 GPU 或 TPU 上透明运行(如果没有 GPU 或 TPU,则回退到 CPU)。

为了确保通用性,skscope 依赖于仅CPU版本的JAX。 因此,对于希望使用GPU的用户,他们只需按照说明 正确安装与物理设备匹配的JAX版本。例如:

pip install -U "jax[cuda12]"

4.3. 支持稀疏矩阵#

感谢jaxskscope支持将输入矩阵作为稀疏矩阵。尽管使用稀疏矩阵会增加自动微分所需的时间,但它可以显著减少内存使用。下面,我们提供了一个线性回归的例子来展示这一功能。首先,我们导入必要的库并过滤掉警告以获得更清晰的输出。

import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse
from skscope import ScopeSolver
import scipy.sparse as sp

import warnings
warnings.filterwarnings('ignore')

接下来,我们使用scipy.sparse生成一个随机稀疏矩阵,使用JAX将其转换为密集矩阵,然后将其转换为BCOO格式的稀疏矩阵。我们还基于预定义的向量创建了一个目标向量,并添加了一些噪声。

n, p = 150, 30
np.random.seed(0)
random_sparse_matrix = sp.random(n, p, density=0.1, format='coo')
dense_matrix = jnp.array(random_sparse_matrix.toarray())
X = sparse.BCOO.fromdense(dense_matrix)
beta = np.zeros(p)
beta[:3] = [1, 2, 3]
y = X @ beta + np.random.normal(0, 0.1, n)

我们定义了一个简单的普通最小二乘(OLS)损失函数,由ScopeSolver进行最小化。

def ols_loss(params):
    loss = jnp.mean((y - X @ params) ** 2)
    return loss

最后,我们初始化ScopeSolver,指定要选择的特征数量,并求解最优参数。

solver = ScopeSolver(p, sparsity=3)
params_skscope = solver.solve(ols_loss, jit=True)

然后,我们可以得到 params_skscope 作为子集选择的结果。