speechbrain.processing.NMF 模块
非负矩阵分解
- Authors
Cem Subakan
摘要
函数:
此函数在给定NMF模板矩阵的情况下分离混合信号。 |
|
此函数将分离的光谱重建为波形。 |
|
返回复数频谱图的相位。 |
参考
- speechbrain.processing.NMF.spectral_phase(stft)[source]
返回复数频谱图的相位。
- Parameters:
stft (torch.Tensor) – 一个张量,来自stft函数的输出。
- Returns:
阶段
- Return type:
torch.Tensor
Example
>>> BS, nfft, T = 10, 20, 300 >>> X_stft = torch.randn(BS, nfft//2 + 1, T, 2) >>> phase_mix = spectral_phase(X_stft)
- speechbrain.processing.NMF.NMF_separate_spectra(Whats, Xmix)[source]
此函数根据NMF模板矩阵分离混合信号。
- Parameters:
Whats (list) – 此列表包含列表 [W1, W2],其中 W1 和 W2 分别是 对应于 source1 和 source2 的 NMF 模板矩阵。 W1 和 W2 的大小为 [nfft/2 + 1, K],其中 nfft 是 STFT 的 fft 大小, K 是 W 中的向量(模板)数量。
Xmix (torch.Tensor) – 这是混合物的幅度谱。 大小为 [BS x T x nfft//2 + 1],其中, BS = 批量大小,nfft = fft 大小,T = 频谱中的时间步数。
- Returns:
X1hat (源1的分离频谱) – 大小 = [BS x (nfft/2 +1) x T],其中, BS = 批量大小,nfft = fft 大小,T = 频谱中的时间步数。
X2hat (源2的分离频谱) – 大小定义与上述相同。
Example
>>> BS, nfft, T = 4, 20, 400 >>> K1, K2 = 10, 10 >>> W1hat = torch.randn(nfft//2 + 1, K1) >>> W2hat = torch.randn(nfft//2 + 1, K2) >>> Whats = [W1hat, W2hat] >>> Xmix = torch.randn(BS, T, nfft//2 + 1) >>> X1hat, X2hat = NMF_separate_spectra(Whats, Xmix)
- speechbrain.processing.NMF.reconstruct_results(X1hat, X2hat, X_stft, sample_rate, win_length, hop_length)[source]
此函数将分离的光谱重建为波形。
- Parameters:
X1hat (torch.Tensor) – 源1的分离频谱,大小为[BS, nfft/2 + 1, T], 其中,BS = 批量大小,nfft = fft大小,T = 频谱的长度。
X2hat (torch.Tensor) – 源2的分离频谱,大小为 [BS, nfft/2 + 1, T]。 大小定义与 Xhat1 相同。
X_stft (torch.Tensor) – 这是混合物的幅度谱。 大小为 [BS x nfft//2 + 1 x T x 2],其中, BS = 批量大小,nfft = fft 大小,T = 频谱中的时间步数。 最后一个维度用于表示复数。
sample_rate (int) – 我们希望保存结果的采样率(以Hz为单位)。
win_length (int) – STFT窗口的长度(以毫秒为单位)。
hop_length (int) – 我们移动STFT窗口的长度(以毫秒为单位)。
- Returns:
x1hats (列表) – 源1的波形列表。
x2hats (列表) – 源2的波形列表。
Example
>>> BS, nfft, T = 10, 512, 16000 >>> sample_rate, win_length, hop_length = 16000, 25, 10 >>> X1hat = torch.randn(BS, nfft//2 + 1, T) >>> X2hat = torch.randn(BS, nfft//2 + 1, T) >>> X_stft = torch.randn(BS, nfft//2 + 1, T, 2) >>> x1hats, x2hats = reconstruct_results(X1hat, X2hat, X_stft, sample_rate, win_length, hop_length)