speechbrain.utils.kmeans 模块

用于训练kmeans模型的实用工具。

Author
  • Pooneh Mousavi 2023

摘要

函数:

accumulate_and_extract_features

提取特征(SSL模型的输出)并在CPU上累积它们以用于聚类。

fetch_kmeans_model

返回一个具有指定参数的k-means聚类模型。

process_chunks

以指定大小的块处理数据。

save_model

保存一个Kmeans模型。

train

训练一个Kmeans模型。

参考

speechbrain.utils.kmeans.accumulate_and_extract_features(batch, features_list, ssl_model, ssl_layer_num, device)[source]

提取特征(SSL模型的输出)并在CPU上累积它们以用于聚类。

Parameters:
  • batch (tensor) – 数据的单个批次。

  • features_list (list) – 累积特征列表。

  • ssl_model (torch.nn.Module) – 用于提取聚类特征的SSL模型。

  • ssl_layer_num (int) – 指定应使用ssl_model的哪一层的输出。

  • device (str) – CPU 或 GPU。

speechbrain.utils.kmeans.fetch_kmeans_model(n_clusters, init, max_iter, batch_size, tol, max_no_improvement, n_init, reassignment_ratio, random_state, checkpoint_path)[source]

返回一个具有指定参数的k-means聚类模型。

Parameters:
  • n_clusters (MiniBatchKMeans) – 要形成的簇的数量以及要生成的质心的数量。

  • init (int) – 初始化方法:{‘k-means++’’, ‘’random’’}

  • max_iter (int) – 在停止之前,对整个数据集进行迭代的最大次数,独立于任何早期停止准则启发式方法。

  • batch_size (int) – 小批量的大小。

  • tol (float) – 通过平滑的、方差归一化的均值中心平方位置变化的相对中心变化来控制早期停止。

  • max_no_improvement (int) – 控制基于连续的小批量数据未能在平滑惯性上产生改进的早期停止。

  • n_init (int) – 尝试的随机初始化次数

  • reassignment_ratio (float) – 控制中心重新分配的最大计数的比例。

  • random_state (int) – 确定用于质心初始化和随机重新分配随机数生成。

  • checkpoint_path (str) – 保存模型的路径。

Returns:

一个具有指定参数的k-means聚类模型。

Return type:

MiniBatchKMeans

speechbrain.utils.kmeans.process_chunks(data, chunk_size, model)[source]

以指定大小的块处理数据。

Parameters:
  • data (list) – 要处理的整数列表。

  • chunk_size (int) – 每个块的大小。

  • model (MiniBatchKMeans) – 用于训练的初始kmeans模型。

Returns:

model – 用于训练的初始kmeans模型。

Return type:

MiniBatchKMeans

speechbrain.utils.kmeans.train(model, train_set, ssl_model, save_path, ssl_layer_num, kmeans_batch_size=1000, device='cpu', checkpoint_interval=10)[source]

训练一个Kmeans模型。

Parameters:
  • model (MiniBatchKMeans) – 用于训练的初始kmeans模型。

  • train_set (Dataloader) – 训练数据的批次。

  • ssl_model – 用于提取用于聚类的特征的SSL模型。

  • save_path (string) – 保存中间检查点和数据加载器的路径。

  • ssl_layer_num (int) – 指定应使用ssl_model的哪一层的输出。

  • device – CPU 或 GPU。

  • kmeans_batch_size (int) – 小批量的大小。

  • checkpoint_interval (int) – 确定在哪些迭代时保存检查点。

speechbrain.utils.kmeans.save_model(model, checkpoint_path)[source]

保存一个Kmeans模型。

Parameters:
  • model (MiniBatchKMeans) – 要保存的kmeans模型。

  • checkpoint_path (str) – 保存模型的路径。