balanced_batch_generator#
- imblearn.keras.balanced_batch_generator(X, y, *, sample_weight=None, sampler=None, batch_size=32, keep_sparse=False, random_state=None)[source]#
创建一个平衡的批次生成器来训练keras模型。
返回一个生成器——以及每个周期的步数——它被提供给
fit
。采样器定义了在创建批次之前用于平衡数据集的采样策略。采样器应该有一个属性sample_indices_
。- Parameters:
- Xndarray of shape (n_samples, n_features)
原始不平衡数据集。
- yndarray of shape (n_samples,) or (n_samples, n_classes)
关联目标。
- sample_weightndarray of shape (n_samples,), default=None
样本权重。
- samplersampler object, default=None
一个具有属性
sample_indices_
的采样器实例。 默认情况下,使用的采样器是RandomUnderSampler
。- batch_sizeint, default=32
每次梯度更新的样本数量。
- keep_sparsebool, default=False
是否保留输入的稀疏性(即
X
,y
,sample_weight
)。默认情况下,返回的批次将是密集的。- random_stateint, RandomState instance, default=None
控制算法的随机化。
如果是整数,
random_state
是随机数生成器使用的种子;如果
RandomState
实例,random_state 是随机数生成器;如果
None
,随机数生成器是np.random
使用的RandomState
实例。
- Returns:
- generatorgenerator of tuple
生成一批数据。生成的元组可能是 (X_batch, y_batch) 或 (X_batch, y_batch, sampler_weight_batch)。
- steps_per_epochint
每个时期的样本数量。由keras中的
fit_generator
所需。
示例
>>> from sklearn.datasets import load_iris >>> X, y = load_iris(return_X_y=True) >>> from imblearn.datasets import make_imbalance >>> class_dict = dict() >>> class_dict[0] = 30; class_dict[1] = 50; class_dict[2] = 40 >>> from imblearn.datasets import make_imbalance >>> X, y = make_imbalance(X, y, sampling_strategy=class_dict) >>> import tensorflow >>> y = tensorflow.keras.utils.to_categorical(y, 3) >>> model = tensorflow.keras.models.Sequential() >>> model.add( ... tensorflow.keras.layers.Dense( ... y.shape[1], input_dim=X.shape[1], activation='softmax' ... ) ... ) >>> model.compile(optimizer='sgd', loss='categorical_crossentropy', ... metrics=['accuracy']) >>> from imblearn.keras import balanced_batch_generator >>> from imblearn.under_sampling import NearMiss >>> training_generator, steps_per_epoch = balanced_batch_generator( ... X, y, sampler=NearMiss(), batch_size=10, random_state=42) >>> callback_history = model.fit(training_generator, ... steps_per_epoch=steps_per_epoch, ... epochs=10, verbose=0)