使用 J-不变性 校准降噪器的完整教程#

在这个例子中,我们展示了如何找到任何去噪算法的最佳校准版本。

校准方法基于 [1]noise2self 算法。

参见

该方法的一个简单示例在 使用 J-不变性校准去噪器 中给出。

校准小波降噪器

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import gridspec

from skimage.data import chelsea, hubble_deep_field
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.restoration import (
    calibrate_denoiser,
    denoise_wavelet,
    denoise_tv_chambolle,
    denoise_nl_means,
    estimate_sigma,
)
from skimage.util import img_as_float, random_noise
from skimage.color import rgb2gray
from functools import partial

_denoise_wavelet = partial(denoise_wavelet, rescale_sigma=True)

image = img_as_float(chelsea())
sigma = 0.2
noisy = random_noise(image, var=sigma**2)

# Parameters to test when calibrating the denoising algorithm
parameter_ranges = {
    'sigma': np.arange(0.1, 0.3, 0.02),
    'wavelet': ['db1', 'db2'],
    'convert2ycbcr': [True, False],
    'channel_axis': [-1],
}

# Denoised image using default parameters of `denoise_wavelet`
default_output = denoise_wavelet(noisy, channel_axis=-1, rescale_sigma=True)

# Calibrate denoiser
calibrated_denoiser = calibrate_denoiser(
    noisy, _denoise_wavelet, denoise_parameters=parameter_ranges
)

# Denoised image using calibrated denoiser
calibrated_output = calibrated_denoiser(noisy)

fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(15, 5))

for ax, img, title in zip(
    axes,
    [noisy, default_output, calibrated_output],
    ['Noisy Image', 'Denoised (Default)', 'Denoised (Calibrated)'],
):
    ax.imshow(img)
    ax.set_title(title)
    ax.set_yticks([])
    ax.set_xticks([])
Noisy Image, Denoised (Default), Denoised (Calibrated)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.13240256245502907..0.932841463326229].

自监督损失和J不变性#

这种校准方法的关键在于 J-不变性 的概念。如果一个去噪函数对每个像素的预测不依赖于该像素在原始图像中的值,那么它就是 J-不变的。每个像素的预测可以使用图像中其余部分包含的所有相关信息,这通常是非常重要的。任何函数都可以通过简单的掩码过程转换为 J-不变的函数,如 [1] 中所述。

J-不变去噪器的像素级误差与噪声无关,只要每个像素的噪声是独立的。因此,去噪图像与噪声图像之间的平均差异,即 自监督损失,与去噪图像与原始干净图像之间的差异,即 *真实损失*(至多相差一个常数)相同。

这意味着对于给定的图像,可以通过选择最小化自监督损失的去噪器,仅使用噪声数据来找到最佳的J-不变去噪器。下面,我们为一系列具有不同`sigma`参数的小波去噪器演示了这一点。自监督损失(实蓝线)和真实损失(虚蓝线)具有相同的形状和相同的极小值。

from skimage.restoration import denoise_invariant

sigma_range = np.arange(sigma / 2, 1.5 * sigma, 0.025)

parameters_tested = [
    {'sigma': sigma, 'convert2ycbcr': True, 'wavelet': 'db2', 'channel_axis': -1}
    for sigma in sigma_range
]

denoised_invariant = [
    denoise_invariant(noisy, _denoise_wavelet, denoiser_kwargs=params)
    for params in parameters_tested
]

self_supervised_loss = [mse(img, noisy) for img in denoised_invariant]
ground_truth_loss = [mse(img, image) for img in denoised_invariant]

opt_idx = np.argmin(self_supervised_loss)
plot_idx = [0, opt_idx, len(sigma_range) - 1]


def get_inset(x):
    return x[25:225, 100:300]


plt.figure(figsize=(10, 12))

gs = gridspec.GridSpec(3, 3)
ax1 = plt.subplot(gs[0, :])
ax2 = plt.subplot(gs[1, :])
ax_image = [plt.subplot(gs[2, i]) for i in range(3)]

ax1.plot(sigma_range, self_supervised_loss, color='C0', label='Self-Supervised Loss')
ax1.scatter(
    sigma_range[opt_idx],
    self_supervised_loss[opt_idx] + 0.0003,
    marker='v',
    color='red',
    label='optimal sigma',
)

ax1.set_ylabel('MSE')
ax1.set_xticks([])
ax1.legend()
ax1.set_title('Self-Supervised Loss')

ax2.plot(
    sigma_range,
    ground_truth_loss,
    color='C0',
    linestyle='--',
    label='Ground Truth Loss',
)
ax2.scatter(
    sigma_range[opt_idx],
    ground_truth_loss[opt_idx] + 0.0003,
    marker='v',
    color='red',
    label='optimal sigma',
)
ax2.set_ylabel('MSE')
ax2.legend()
ax2.set_xlabel('sigma')
ax2.set_title('Ground-Truth Loss')

for i in range(3):
    ax = ax_image[i]
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(get_inset(denoised_invariant[plot_idx[i]]))
    ax.set_xlabel('sigma = ' + str(np.round(sigma_range[plot_idx[i]], 2)))

for spine in ax_image[1].spines.values():
    spine.set_edgecolor('red')
    spine.set_linewidth(5)
Self-Supervised Loss, Ground-Truth Loss
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0466469856470723..0.9921529309987775].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.13240256245502913..0.9328414633262291].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.059235880117713725..0.8926423123991685].

转换为 J-不变量#

函数 _invariant_denoise 充当给定去噪器的 J-不变版本。它通过屏蔽部分像素、插值、运行原始去噪器并提取屏蔽像素中返回的值来工作。对图像进行迭代会产生完全 J-不变的输出。

对于任何给定的一组参数,去噪器的J-不变版本与原始去噪器不同,但不一定更好或更差。在下图中,我们看到,对于猫的测试图像,在小方差减少 sigma 值时,小波去噪器的J-不变版本明显优于原始版本,而在较大值时则几乎无差别。

parameters_tested = [
    {'sigma': sigma, 'convert2ycbcr': True, 'wavelet': 'db2', 'channel_axis': -1}
    for sigma in sigma_range
]

denoised_original = [_denoise_wavelet(noisy, **params) for params in parameters_tested]

ground_truth_loss_invariant = [mse(img, image) for img in denoised_invariant]
ground_truth_loss_original = [mse(img, image) for img in denoised_original]

fig, ax = plt.subplots(figsize=(10, 4))

ax.plot(
    sigma_range,
    ground_truth_loss_invariant,
    color='C0',
    linestyle='--',
    label='J-invariant',
)
ax.plot(
    sigma_range,
    ground_truth_loss_original,
    color='C1',
    linestyle='--',
    label='Original',
)
ax.scatter(
    sigma_range[opt_idx], ground_truth_loss[opt_idx] + 0.001, marker='v', color='red'
)
ax.legend()
ax.set_title(
    'J-Invariant Denoiser Has Comparable Or ' 'Better Performance At Same Parameters'
)
ax.set_ylabel('MSE')
ax.set_xlabel('sigma')
J-Invariant Denoiser Has Comparable Or Better Performance At Same Parameters
Text(0.5, 14.722222222222216, 'sigma')

比较不同类别的降噪器#

自监督损失不仅可以用于比较不同类别的去噪器,还可以为单个类别选择参数。这使得用户能够以无偏的方式,为给定图像选择最适合的去噪器类别的最佳参数。

下面,我们展示了一个添加了显著散斑噪声的哈勃深场图像的例子。在这种情况下,J-不变校准去噪器在三种去噪器家族——非局部均值、小波和TV范数中,都优于默认去噪器。此外,自监督损失显示,TV范数去噪器是处理这种噪声图像的最佳选择。

image = rgb2gray(img_as_float(hubble_deep_field()[100:250, 50:300]))

sigma = 0.4
noisy = random_noise(image, mode='speckle', var=sigma**2)

parameter_ranges_tv = {'weight': np.arange(0.01, 0.3, 0.02)}
_, (parameters_tested_tv, losses_tv) = calibrate_denoiser(
    noisy,
    denoise_tv_chambolle,
    denoise_parameters=parameter_ranges_tv,
    extra_output=True,
)
print(f'Minimum self-supervised loss TV: {np.min(losses_tv):.4f}')

best_parameters_tv = parameters_tested_tv[np.argmin(losses_tv)]
denoised_calibrated_tv = denoise_invariant(
    noisy, denoise_tv_chambolle, denoiser_kwargs=best_parameters_tv
)
denoised_default_tv = denoise_tv_chambolle(noisy, **best_parameters_tv)

psnr_calibrated_tv = psnr(image, denoised_calibrated_tv)
psnr_default_tv = psnr(image, denoised_default_tv)

parameter_ranges_wavelet = {'sigma': np.arange(0.01, 0.3, 0.03)}
_, (parameters_tested_wavelet, losses_wavelet) = calibrate_denoiser(
    noisy, _denoise_wavelet, parameter_ranges_wavelet, extra_output=True
)
print(f'Minimum self-supervised loss wavelet: {np.min(losses_wavelet):.4f}')

best_parameters_wavelet = parameters_tested_wavelet[np.argmin(losses_wavelet)]
denoised_calibrated_wavelet = denoise_invariant(
    noisy, _denoise_wavelet, denoiser_kwargs=best_parameters_wavelet
)
denoised_default_wavelet = _denoise_wavelet(noisy, **best_parameters_wavelet)

psnr_calibrated_wavelet = psnr(image, denoised_calibrated_wavelet)
psnr_default_wavelet = psnr(image, denoised_default_wavelet)

sigma_est = estimate_sigma(noisy)

parameter_ranges_nl = {
    'sigma': np.arange(0.6, 1.4, 0.2) * sigma_est,
    'h': np.arange(0.6, 1.2, 0.2) * sigma_est,
}

parameter_ranges_nl = {'sigma': np.arange(0.01, 0.3, 0.03)}
_, (parameters_tested_nl, losses_nl) = calibrate_denoiser(
    noisy, denoise_nl_means, parameter_ranges_nl, extra_output=True
)
print(f'Minimum self-supervised loss NL means: {np.min(losses_nl):.4f}')

best_parameters_nl = parameters_tested_nl[np.argmin(losses_nl)]
denoised_calibrated_nl = denoise_invariant(
    noisy, denoise_nl_means, denoiser_kwargs=best_parameters_nl
)
denoised_default_nl = denoise_nl_means(noisy, **best_parameters_nl)

psnr_calibrated_nl = psnr(image, denoised_calibrated_nl)
psnr_default_nl = psnr(image, denoised_default_nl)

print('                       PSNR')
print(f'NL means (Default)   : {psnr_default_nl:.1f}')
print(f'NL means (Calibrated): {psnr_calibrated_nl:.1f}')
print(f'Wavelet  (Default)   : {psnr_default_wavelet:.1f}')
print(f'Wavelet  (Calibrated): {psnr_calibrated_wavelet:.1f}')
print(f'TV norm  (Default)   : {psnr_default_tv:.1f}')
print(f'TV norm  (Calibrated): {psnr_calibrated_tv:.1f}')

plt.subplots(figsize=(10, 12))
plt.imshow(noisy, cmap='Greys_r')
plt.xticks([])
plt.yticks([])
plt.title('Noisy Image')


def get_inset(x):
    return x[0:100, -140:]


fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(15, 8))

for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])

axes[0, 0].imshow(get_inset(denoised_default_nl), cmap='Greys_r')
axes[0, 0].set_title('NL Means Default')
axes[1, 0].imshow(get_inset(denoised_calibrated_nl), cmap='Greys_r')
axes[1, 0].set_title('NL Means Calibrated')
axes[0, 1].imshow(get_inset(denoised_default_wavelet), cmap='Greys_r')
axes[0, 1].set_title('Wavelet Default')
axes[1, 1].imshow(get_inset(denoised_calibrated_wavelet), cmap='Greys_r')
axes[1, 1].set_title('Wavelet Calibrated')
axes[0, 2].imshow(get_inset(denoised_default_tv), cmap='Greys_r')
axes[0, 2].set_title('TV Norm Default')
axes[1, 2].imshow(get_inset(denoised_calibrated_tv), cmap='Greys_r')
axes[1, 2].set_title('TV Norm Calibrated')

for spine in axes[1, 2].spines.values():
    spine.set_edgecolor('red')
    spine.set_linewidth(5)

plt.show()
  • Noisy Image
  • NL Means Default, Wavelet Default, TV Norm Default, NL Means Calibrated, Wavelet Calibrated, TV Norm Calibrated
Minimum self-supervised loss TV: 0.0033
Minimum self-supervised loss wavelet: 0.0033
Minimum self-supervised loss NL means: 0.0036
                       PSNR
NL means (Default)   : 25.9
NL means (Calibrated): 27.5
Wavelet  (Default)   : 26.1
Wavelet  (Calibrated): 29.3
TV norm  (Default)   : 28.3
TV norm  (Calibrated): 29.6

脚本总运行时间: (0 分钟 6.866 秒)

由 Sphinx-Gallery 生成的图库