Wavelet denoising is a powerful technique used to clean noisy signals, and PyWavelets offers an efficient implementation of the discrete wavelet transform (DWT). In this article, we will explore the process of wavelet denoising using PyWavelets, assuming a fundamental understanding of wavelet transforms.
The problem
We start with a signal distorted by white noise. Our goal is to recover the original signal. Consider the following example where we generate a Doppler signal and add noise to it:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import pywt
def doppler(freqs, dt, amp_inc=10, t0=0, f0=np.pi*2):
    """Generate a sinusoid whose frequency and amplitude change per sample.

    Parameters
    ----------
    freqs : sequence of numbers
        One frequency value per output sample. Any sequence is accepted
        (list, tuple, ndarray); it is converted to an array first so that
        plain Python lists can be multiplied by the float ``f0``.
    dt : float
        Spacing between consecutive time stamps.
    amp_inc : float, optional
        Amplitude reached at the last sample; the envelope starts at 1 and
        grows quadratically to this value.
    t0 : float, optional
        Time stamp of the first sample.
    f0 : float, optional
        Factor applied to ``freqs * t`` inside the sine (2*pi by default).

    Returns
    -------
    t : ndarray
        Time stamps ``t0, t0 + dt, ...`` (one per entry of ``freqs``).
    sig : ndarray
        ``amp * sin(freqs * f0 * t)`` evaluated at those time stamps.
    """
    freqs = np.asarray(freqs)
    n_samples = len(freqs)
    t = np.arange(n_samples) * dt + t0
    # Squaring a linear ramp from 1 to sqrt(amp_inc) gives an envelope that
    # runs from exactly 1 to exactly amp_inc.
    amp = np.linspace(1, np.sqrt(amp_inc), n_samples) ** 2
    sig = amp * np.sin(freqs * f0 * t)
    return t, sig
def noisify(sig, noise_amp=1):
    """Return ``sig`` with additive uniform noise in [-noise_amp, noise_amp).

    One independent sample is drawn per element of ``sig`` from NumPy's
    global random state, so results are reproducible only after
    ``np.random.seed``.
    """
    # random() is uniform on [0, 1); recentre it around zero before scaling.
    centred = np.random.random(len(sig)) - 0.5
    return sig + centred * 2 * noise_amp
# Clean test signal: the frequency ramp runs from just under 20 down to 10,
# sampled every 0.002 time units.
t_dop, sig_dop = doppler(np.arange(10, 20, 0.01)[::-1], 0.002)
# Noisy copy used as the input of every denoising method below.
sig_dop_n2 = noisify(sig_dop, noise_amp=2)

# Clean and noisy signals side by side.
plt.figure(figsize=(16, 4))
for _slot, _title, _samples in (
    (121, "Original Signal", sig_dop),
    (122, "Noisy Signal", sig_dop_n2),
):
    plt.subplot(_slot)
    plt.plot(t_dop, _samples)
    plt.title(_title)
plt.show()
Our objective is to transform the noisy signal back to its original form.
Fourier transform approach
One way to approach this is by using Fourier Transform denoising, which filters out frequencies outside the known range of the original signal:
def fourier_denoising(sig, min_freq, max_freq, dt=1.0):
    """Keep only the FFT bins whose |frequency| lies in [min_freq, max_freq].

    Parameters
    ----------
    sig : array_like
        Samples of the signal to filter.
    min_freq, max_freq : float
        Inclusive pass band, compared against the absolute bin frequency so
        positive and negative frequencies are treated alike.
    dt : float, optional
        Sample spacing, used to convert bin indices to frequencies.

    Returns
    -------
    ndarray
        Real part of the inverse FFT of the masked spectrum.
    """
    spectrum = np.fft.fft(sig)
    bin_freqs = np.abs(np.fft.fftfreq(len(sig), d=dt))
    # Everything strictly outside the pass band is zeroed.
    rejected = (bin_freqs < min_freq) | (bin_freqs > max_freq)
    spectrum[rejected] = 0
    return np.fft.ifft(spectrum).real
# Magnitude spectra of the clean and the noisy signal.
fsig_dop = np.abs(np.fft.fft(sig_dop))
fsig_dop_n2 = np.abs(np.fft.fft(sig_dop_n2))
freqs_dop = np.fft.fftfreq(len(sig_dop), d=0.002)

# Restrict the plot to |f| < 50. fftfreq returns the non-negative
# frequencies first and the negative ones after them, so the selected bins
# are sorted by frequency; plotting them in storage order would draw a
# spurious straight line from the highest positive bin back to the lowest
# negative one.
idx = np.where(np.abs(freqs_dop) < 50)[0]
idx = idx[np.argsort(freqs_dop[idx])]
plt.plot(freqs_dop[idx], fsig_dop[idx], label="FFT(signal)")
plt.plot(freqs_dop[idx], fsig_dop_n2[idx], label="FFT(signal + noise)")
plt.legend(loc="best")
plt.show()

# Keep only the 0-20 band of the noisy signal and compare with the original.
fsig_dop_fden = fourier_denoising(sig_dop_n2, 0, 20, dt=0.002)
plt.plot(t_dop, sig_dop, lw=6, alpha=0.3, label="Original Signal")
plt.plot(t_dop, fsig_dop_fden, "r-", label="Denoising Result")
plt.legend(loc="best")
plt.show()
This approach only works when the signal's frequency range is known in advance, and it cannot remove the part of the noise that falls inside that range.
Introduction to PyWavelets
PyWavelets provides a simple way to perform the discrete wavelet transform (DWT). Let's start by exploring available wavelet families and individual wavelets:
# Names of the wavelet families shipped with PyWavelets.
wavelet_families = pywt.families()
print(wavelet_families)
# Individual wavelets that belong to the "sym" family.
symlet_names = pywt.wavelist("sym")
print(symlet_names)
We will use the Haar and Symlet wavelets for our demonstration:
haar = pywt.Wavelet("haar")
# Wavelet.filter_bank is ordered (dec_lo, dec_hi, rec_lo, rec_hi), so index 0
# is the LOW-pass decomposition filter and index 1 the HIGH-pass one. The
# named attributes are used here so each label matches the filter it prints.
print("Haar wavelet highpass filter:", haar.dec_hi)
print("Haar wavelet lowpass filter:", haar.dec_lo)
# Summary of the wavelet's properties.
print(haar)
sym12 = pywt.Wavelet("sym12")
# Approximate the scaling function (phi) and the wavelet function (psi) on a
# common x grid; 8 is the refinement level passed to wavefun.
phi_s12, psi_s12, x_s12 = sym12.wavefun(8)

plt.figure(figsize=(16, 4))
plt.subplot(121)
# Raw strings keep a single backslash so mathtext sees the \phi / \psi
# commands; the previous over-escaped literals evaluated to a double
# backslash, which is not the Greek-letter command.
plt.title(r"$\phi$")
plt.plot(x_s12, phi_s12)
plt.subplot(122)
plt.title(r"$\psi$")
plt.plot(x_s12, psi_s12)
plt.show()
Performing DWT with PyWavelets
We decompose the noisy signal using the DWT and examine the approximation and detail coefficients:
# Single-level DWT of the noisy signal with zero-padding at the borders.
cA, cD = pywt.dwt(sig_dop_n2, "sym12", mode="zero")

# Overlay both coefficient sets on one axis.
for _band, _label in (
    (cA, "Approximation Coefficients"),
    (cD, "Detail Coefficients"),
):
    plt.plot(_band, label=_label)
plt.legend(loc="best")
plt.show()
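To analyze the full multilevel decomposition, we apply pywt.wavedec to both the clean and the noisy signal and plot the coefficients level by level: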
# Multilevel decompositions of the clean and the noisy signal.
coeffs = pywt.wavedec(sig_dop, "sym12")
coeffs_n = pywt.wavedec(sig_dop_n2, "sym12")
# The first entry holds the approximation coefficients; every remaining
# entry is one level of detail coefficients.
approx, *details = coeffs
approx_n, *details_n = coeffs_n
def plot_dwt(details, approx, xlim=(-300, 300), **line_kwargs):
    """Plot detail and approximation coefficients as a column of subplots.

    The current figure gets ``len(details) + 1`` rows. Row ``i`` shows
    ``details[-1 - i]`` (so ``detail[0]`` is the LAST entry of ``details``)
    with its x positions centred on zero and stretched by ``2**i``. The
    final row shows ``approx``.

    Parameters
    ----------
    details : sequence of array_like
        Detail coefficient arrays, plotted from last to first.
    approx : array_like
        Approximation coefficients, plotted in the bottom row.
    xlim : tuple, optional
        x-axis limits applied to every row.
    **line_kwargs
        Forwarded to every ``plt.plot`` call (colour, alpha, ...).

    Notes
    -----
    The approximation row uses the same stretch factor as the last detail
    row but derives its x positions from ``len(approx)``. It no longer
    reuses the x array left over from the loop, so an empty ``details`` is
    handled instead of raising ``NameError``.
    """

    def centred_positions(length, scale):
        # Integer grid of `length` points centred on zero, stretched by scale.
        half = length // 2
        return np.arange(-half, -half + length) * scale

    n_rows = len(details) + 1
    for i, d in enumerate(reversed(details)):
        plt.subplot(n_rows, 1, i + 1)
        plt.plot(centred_positions(len(d), 2**i), d, **line_kwargs)
        plt.xlim(xlim)
        plt.title("detail[{}]".format(i))

    plt.subplot(n_rows, 1, n_rows)
    plt.title("approx")
    # Same stretch as the last detail row (2**0 when there are no details).
    approx_scale = 2 ** max(len(details) - 1, 0)
    plt.plot(centred_positions(len(approx), approx_scale), approx, **line_kwargs)
    plt.xlim(xlim)
# Clean decomposition (default style) with the noisy one overlaid in red.
plt.figure(figsize=(15, 24))
_overlays = (
    (details, approx, {}),
    (details_n, approx_n, {"color": "red", "alpha": 0.5}),
)
for _detail_set, _approx_set, _line_style in _overlays:
    plot_dwt(_detail_set, _approx_set, **_line_style)
plt.show()
Wavelet denoising with NeighBlock
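We denoise with the NeighBlock method, which shrinks the detail coefficients block by block: the shrinkage factor of each block is computed from the energy of a larger window that also contains the block's neighboring coefficients: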
def neigh_block(details, n, sigma):
    """Shrink wavelet detail coefficients block-wise (NeighBlock rule).

    Every detail array is cut into consecutive blocks of ``L0``
    coefficients. Each block is multiplied by

        beta = max(0, 1 - lmbd * L * sigma**2 / S2)

    where ``S2`` is the sum of squares of the ORIGINAL coefficients in a
    window of ``L = L0 + 2 * L1`` coefficients around the block. Near the
    array borders the window is shifted inwards so it stays inside the
    array.

    Parameters
    ----------
    details : iterable of array_like
        Detail coefficient arrays (one per decomposition level).
    n : int
        Length of the signal the coefficients were computed from; it fixes
        the block length ``L0 = floor(log2(n) / 2)`` (at least 1).
    sigma : float
        Noise standard deviation used in the threshold.

    Returns
    -------
    list of ndarray
        Shrunk copies of the input arrays, in the same order. The inputs
        are never modified.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.

    Notes
    -----
    Differences from the previous implementation:

    * ``S2`` is taken from the untouched input. Before, it was read from
      the array being shrunk in place, so coefficients already scaled by an
      earlier block lowered the energy seen by the next one.
    * Arrays shorter than ``L`` use the whole array as window (and its real
      length in the threshold) instead of producing a negative start index
      that wrapped around.
    * ``n < 4`` no longer yields a zero block length.
    * A zero-energy window gives ``beta = 0`` without a divide-by-zero
      warning.
    * Integer input is converted to float instead of failing on the
      in-place multiplication.
    """
    if n < 1:
        raise ValueError("n must be a positive number of samples")

    L0 = max(1, int(np.log2(n) // 2))  # block length
    L1 = max(1, L0 // 2)               # extension on each side of a block
    L = L0 + 2 * L1                    # nominal window length
    lmbd = 4.50524  # solution of lmbd - log(lmbd) = 3

    def nb_beta(window):
        # Shrinkage factor for one block, given its surrounding window.
        S2 = np.sum(window ** 2)
        if S2 == 0:
            return 0.0
        return max(0, 1 - lmbd * len(window) * sigma**2 / S2)

    res = []
    for d in details:
        noisy = np.asarray(d)
        if not np.issubdtype(noisy.dtype, np.inexact):
            noisy = noisy.astype(float)
        shrunk = noisy.copy()
        size = len(noisy)
        # Right-most admissible window start (0 when the array is short).
        last_start = max(size - L, 0)
        for start_b in range(0, size, L0):
            end_b = min(size, start_b + L0)
            # Centre the window on the block, then slide it back inside
            # the array if it sticks out on either side.
            start_B = min(max(start_b - L1, 0), last_start)
            end_B = min(start_B + L, size)
            shrunk[start_b:end_b] *= nb_beta(noisy[start_B:end_B])
        res.append(shrunk)
    return res
# Shrink the noisy detail coefficients; 0.8 is the noise level handed to
# neigh_block as sigma.
# NOTE(review): noisify() above adds uniform noise in [-2, 2) - confirm that
# 0.8 is the intended sigma for that noise level.
details_nb = neigh_block(details_n, len(sig_dop), 0.8)

# Clean (default style), noisy (red) and shrunk (green) coefficients overlaid.
plt.figure(figsize=(15, 24))
_overlays = (
    (details, approx, {}),
    (details_n, approx_n, {"color": "red", "alpha": 0.5}),
    (details_nb, approx_n, {"color": "green", "alpha": 0.5, "lw": 2}),
)
for _detail_set, _approx_set, _line_style in _overlays:
    plot_dwt(_detail_set, _approx_set, **_line_style)
plt.show()
Finally, reconstruct the signal:
# Inverse transform from the untouched noisy approximation plus the shrunk
# detail coefficients.
sig_dop_dn = pywt.waverec([approx_n, *details_nb], "sym12")

plt.figure(figsize=(15, 4))
plt.title("Denoised Signal vs Original Signal")
for _trace, _label in (
    (sig_dop, "Original Signal"),
    (sig_dop_dn, "Denoised Signal"),
):
    plt.plot(_trace, label=_label)
plt.legend()
plt.show()
Conclusion
Wavelet denoising using PyWavelets effectively removes noise while preserving signal features, making it a robust technique for signal processing. If you encounter any issues, please get in touch with our support. Happy coding in Deepnote!