StratifiedGroupKFold#

class sklearn.model_selection.StratifiedGroupKFold(n_splits=5, shuffle=False, random_state=None)#

分层K折迭代器变种,具有不重叠的组。

这种交叉验证对象是StratifiedKFold的一种变体,试图返回具有不重叠组的分层折叠。折叠是通过保留每个类别的样本百分比来制作的。

每个组在所有折叠中将恰好出现在测试集中一次(不同组的数量至少要等于折叠的数量)。

GroupKFoldStratifiedGroupKFold 之间的区别在于,前者试图创建平衡的折叠,使得每个折叠中的不同组数量大致相同,而StratifiedGroupKFold试图创建尽可能保留每个类别样本百分比的折叠,同时在分割之间保持不重叠的组。

更多信息请参阅 用户指南

有关交叉验证行为的可视化以及常见scikit-learn分割方法之间的比较,请参阅 在 scikit-learn 中可视化交叉验证行为

Parameters:
n_splitsint, default=5

折叠数量。必须至少为2。

shufflebool, default=False

在分成批次之前是否对每个类别的样本进行洗牌。请注意,每个分割内的样本不会被洗牌。此实现只能对具有大致相同y分布的组进行洗牌,不会进行全局洗牌。

random_stateint or RandomState instance, default=None

shuffle 为True时, random_state 影响索引的顺序,这控制每个类别每个折叠的随机性。否则,将 random_state 保留为 None 。传递一个int以在多次函数调用中获得可重复的输出。请参阅 Glossary

See also

StratifiedKFold

考虑类别信息以构建保留类别分布的折叠(用于二分类或多分类任务)。

GroupKFold

具有不重叠组的K折迭代器变种。

Notes

该实现旨在:

  • 尽可能模拟StratifiedKFold的行为,对于平凡的组(例如,当每个组仅包含一个样本时)。

  • 对类别标签不变:将 y = ["Happy", "Sad"] 重新标记为 y = [1, 0] 不应改变生成的索引。

  • 尽可能基于样本进行分层,同时保持不重叠组的约束。这意味着在某些情况下,当有少量组包含大量样本时,分层将不可能,行为将接近GroupKFold。

Examples

>>> import numpy as np
>>> from sklearn.model_selection import StratifiedGroupKFold
>>> X = np.ones((17, 2))
>>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
>>> sgkf = StratifiedGroupKFold(n_splits=3)
>>> sgkf.get_n_splits(X, y)
3
>>> print(sgkf)
StratifiedGroupKFold(n_splits=3, random_state=None, shuffle=False)
>>> for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
...     print(f"Fold {i}:")
...     print(f"  Train: index={train_index}")
...     print(f"         group={groups[train_index]}")
...     print(f"  Test:  index={test_index}")
...     print(f"         group={groups[test_index]}")
Fold 0:
  Train: index=[ 0  1  2  3  7  8  9 10 11 15 16]
         group=[1 1 2 2 4 5 5 5 5 8 8]
  Test:  index=[ 4  5  6 12 13 14]
         group=[3 3 3 6 6 7]
Fold 1:
  Train: index=[ 4  5  6  7  8  9 10 11 12 13 14]
         group=[3 3 3 4 5 5 5 5 6 6 7]
  Test:  index=[ 0  1  2  3 15 16]
         group=[1 1 2 2 8 8]
Fold 2:
  Train: index=[ 0  1  2  3  4  5  6 12 13 14 15 16]
         group=[1 1 2 2 3 3 3 6 6 7 8 8]
  Test:  index=[ 7  8  9 10 11]
         group=[4 5 5 5 5]
get_metadata_routing()#

获取此对象的元数据路由。

请查看 用户指南 以了解路由机制的工作原理。

Returns:
routingMetadataRequest

MetadataRequest 封装的 路由信息。

get_n_splits(X=None, y=None, groups=None)#

返回交叉验证器中的分割迭代次数。

Parameters:
Xobject

总是被忽略,存在是为了兼容性。

yobject

总是被忽略,存在是为了兼容性。

groupsobject

总是被忽略,存在是为了兼容性。

Returns:
n_splitsint

返回交叉验证器中的分割迭代次数。

set_split_request(*, groups: bool | None | str = '$UNCHANGED$') StratifiedGroupKFold#

Request metadata passed to the split method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config ). Please see User Guide on how the routing mechanism works.

The options for each parameter are:

  • True : metadata is requested, and passed to split if provided. The request is ignored if metadata is not provided.

  • False : metadata is not requested and the meta-estimator will not pass it to split .

  • None : metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str : metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default ( sklearn.utils.metadata_routing.UNCHANGED ) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline . Otherwise it has no effect.

Parameters:
groupsstr, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for groups parameter in split .

Returns:
selfobject

The updated object.

split(X, y=None, groups=None)#

生成用于将数据分割为训练集和测试集的索引。

Parameters:
X形状为 (n_samples, n_features) 的类数组

训练数据,其中 n_samples 是样本数量 且 n_features 是特征数量。

y形状为 (n_samples,) 的类数组,默认=None

监督学习问题的目标变量。

groups形状为 (n_samples,) 的类数组,默认=None

在将数据集分割为训练/测试集时使用的样本组标签。

Yields:
trainndarray

该分割的训练集索引。

testndarray

该分割的测试集索引。