Transformer For Denoising
Although small datasets make denoising a challenging task, this experiment shows how transformer models can handle audio cleaning in a practical way. The idea is simple: can we teach a transformer — originally built for language — to remove noise from speech recordings?
What’s Happening Inside
- Custom Dataset: The code loads pairs of noisy and clean audio files, converts them into log-magnitude spectrograms (frequency–time representations), normalizes them, and prepares them for supervised learning.
- Transformer Denoiser: A lightweight transformer encoder maps noisy spectrograms → clean ones. Each time frame acts like a “token,” and attention helps the model focus on important spectral regions (a shape-tracing sketch follows this list).
- Loss and Optimization: The network minimizes mean squared error (MSE) between predicted and clean spectrograms. Gradient clipping and a slightly lower learning rate stabilize training and prevent spikes.
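To make the “time frames as tokens” idea concrete, here is a small shape-tracing sketch. It is my own illustration, not part of the training script: with n_fft=512 and hop_length=128, one second of 16 kHz audio (matching the 16k folder names used below) becomes 126 frames of 257 frequency bins each, so the transformer sees 126 tokens of dimension 257.
import torch
import torchaudio.transforms as T

waveform = torch.randn(1, 16000)  # one second of (fake) 16 kHz audio
spec = T.Spectrogram(n_fft=512, hop_length=128, power=None)(waveform)
print(spec.shape)    # torch.Size([1, 257, 126]) -> (channel, freq, time)

tokens = torch.log1p(spec.abs()).squeeze(0).transpose(0, 1)
print(tokens.shape)  # torch.Size([126, 257])    -> 126 "tokens", each a 257-bin spectral frame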
Once trained, the model predicts a clean magnitude spectrogram from a noisy waveform and reconstructs the waveform using Griffin-Lim phase estimation. The result is a denoised audio file ready to play. The script also visualizes (see the sketch after this list):
- Noisy vs. denoised waveforms
- Their differences
- Spectrograms for clearer comparison
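The plotting code is not included in the listing below; a minimal matplotlib sketch of those comparisons could look like this, assuming the waveform (noisy input) and denoised_waveform tensors from the inference section of the script:
import matplotlib.pyplot as plt

noisy = waveform.squeeze(0).cpu().numpy()
denoised = denoised_waveform.cpu().numpy()
n = min(len(noisy), len(denoised))  # Griffin-Lim output length can differ slightly

# Noisy vs. denoised waveforms and their difference
fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
axes[0].plot(noisy[:n])
axes[0].set_title("Noisy waveform")
axes[1].plot(denoised[:n])
axes[1].set_title("Denoised waveform")
axes[2].plot(noisy[:n] - denoised[:n])
axes[2].set_title("Difference")
plt.tight_layout()
plt.show()

# Spectrograms for clearer comparison
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].specgram(noisy[:n], NFFT=512, noverlap=384, Fs=sr)
axes[0].set_title("Noisy spectrogram")
axes[1].specgram(denoised[:n], NFFT=512, noverlap=384, Fs=sr)
axes[1].set_title("Denoised spectrogram")
plt.tight_layout()
plt.show()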
This version trains reasonably well thanks to the log-magnitude spectrograms, per-utterance normalization, and gradient clipping. For future experiments, one could explore positional embeddings, learnable embeddings, deeper transformer layers, or local attention mechanisms to capture temporal patterns even better; a sketch of the positional-encoding idea follows below.
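As a starting point for the positional-embedding idea, here is a sketch of a standard sinusoidal positional encoding. The PositionalEncoding class and max_len value are my own illustration, not part of the script; it could be applied right after input_proj inside TransformerDenoiser.
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for batch_first inputs of shape (batch, time, d_model)."""
    def __init__(self, d_model, max_len=4000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, time, d_model); add the encoding for the first `time` positions
        return x + self.pe[:, : x.size(1)]

# Possible use inside TransformerDenoiser (hypothetical attribute name self.pos_enc):
#   x = self.input_proj(x)
#   x = self.pos_enc(x)      # with self.pos_enc = PositionalEncoding(d_model) in __init__
#   x = self.transformer(x)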
Dataset used:
University of Edinburgh DataShare – VoiceBank-DEMAND corpus
(I used smaller clean and noisy test subsets for my experiments.)
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
import os
# ---- Dataset Definition ----
class SpeechDenoiseDataset(Dataset):
    def __init__(self, noisy_dir, clean_dir, n_fft=512, hop_length=128):
        self.noisy_files = sorted([os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir)])
        self.clean_files = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir)])
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.spectrogram = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)

    def __len__(self):
        return len(self.noisy_files)

    def __getitem__(self, idx):
        noisy_wave, sr = torchaudio.load(self.noisy_files[idx])
        clean_wave, _ = torchaudio.load(self.clean_files[idx])
        # Mono
        noisy_wave = noisy_wave.mean(dim=0, keepdim=True)
        clean_wave = clean_wave.mean(dim=0, keepdim=True)
        # Convert to spectrogram
        noisy_spec_complex = self.spectrogram(noisy_wave)
        clean_spec_complex = self.spectrogram(clean_wave)
        # Log-magnitude and normalization
        noisy_spec = torch.log1p(noisy_spec_complex.abs())
        clean_spec = torch.log1p(clean_spec_complex.abs())
        noisy_spec = (noisy_spec - noisy_spec.mean()) / (noisy_spec.std() + 1e-6)
        clean_spec = (clean_spec - clean_spec.mean()) / (clean_spec.std() + 1e-6)
        # Transpose to (time, freq)
        return noisy_spec.squeeze(0).transpose(0, 1), clean_spec.squeeze(0).transpose(0, 1)
# ---- Transformer Denoiser ----
class TransformerDenoiser(nn.Module):
    def __init__(self, n_freq, n_heads=4, n_layers=2, dim_feedforward=256):
        super().__init__()
        d_model = (n_freq // n_heads) * n_heads
        self.input_proj = nn.Linear(n_freq, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.output_proj = nn.Linear(d_model, n_freq)

    def forward(self, x):
        x = self.input_proj(x)
        x = self.transformer(x)
        x = self.output_proj(x)
        return x
# ---- Hyperparameters ----
clean_dir = "cleanTrainset16k"
noisy_dir = "noisyTrainset16k"
n_fft = 512
hop_length = 128
batch_size = 1
epochs = 20
lr = 5e-4 # slightly lower for stability
# ---- DataLoader ----
dataset = SpeechDenoiseDataset(noisy_dir, clean_dir, n_fft, hop_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# ---- Model, Loss, Optimizer ----
example_spec, _ = dataset[0]
n_freq = example_spec.shape[1]
model = TransformerDenoiser(n_freq)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# ---- Training Loop ----
for epoch in range(epochs):
    epoch_loss = 0
    for noisy_spec, clean_spec in dataloader:
        noisy_spec = noisy_spec.to(device)
        clean_spec = clean_spec.to(device)
        optimizer.zero_grad()
        output = model(noisy_spec)
        loss = criterion(output, clean_spec)
        loss.backward()
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {epoch_loss / len(dataloader):.4f}")
# ---- Inference ----
test_noisy_file = "testMonoAudio.wav"
waveform, sr = torchaudio.load(test_noisy_file)
waveform = waveform.mean(dim=0, keepdim=True)
spectrogram = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
spec_complex = spectrogram(waveform)
mag_spec = torch.log1p(spec_complex.abs())
spec_mean, spec_std = mag_spec.mean(), mag_spec.std()
mag_spec = (mag_spec - spec_mean) / (spec_std + 1e-6)
input_spec = mag_spec.squeeze(0).transpose(0, 1).unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
    denoised_mag = model(input_spec)
denoised_mag = denoised_mag.squeeze(0).transpose(0, 1).cpu()
# undo the normalization before inverting log1p, and clamp negative magnitudes
denoised_mag = torch.relu(denoised_mag * (spec_std + 1e-6) + spec_mean)
griffin_lim = T.GriffinLim(n_fft=n_fft, hop_length=hop_length, power=1.0)  # input is a magnitude spectrogram
denoised_waveform = griffin_lim(torch.expm1(denoised_mag))  # invert log1p
torchaudio.save("denoised_audio.wav", denoised_waveform.unsqueeze(0), sr)
print("Denoised audio saved as denoised_audio.wav")[1] Attention Is All Your Need
[2] A Review of Deep Learning Techniques for Speech Processing
[3] University of Edinburgh DataShare – VoiceBank-DEMAND corpus