Simple RNN for Speech Denoising

I recently built a lightweight speech denoiser using a GRU-based recurrent neural network that operates directly on raw audio frames.

What I did:

  • Implemented a SimpleDenoiserRNN with a single GRU layer and a linear output to reconstruct clean audio from noisy inputs.

  • Processed audio in sequences of multiple frames so the GRU can exploit temporal context across frames rather than denoising each frame in isolation (see the shape sketch after this list).

  • Added overlap-add with Hanning windowing during inference to smooth transitions and eliminate block artifacts.

  • Worked with downsampled 16 kHz audio to balance computational cost and speech intelligibility.
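
To make the data shaping concrete, here is a minimal, self-contained sketch of the tensor layout the GRU expects with batch_first=True (the batch size of 32 and the random data are purely illustrative):

import torch
import torch.nn as nn

frame_size, sequence_length = 128, 10
gru = nn.GRU(input_size=frame_size, hidden_size=128, batch_first=True)

# One batch of 32 sequences, each holding 10 consecutive 128-sample frames
x = torch.randn(32, sequence_length, frame_size)
out, h = gru(x)
print(out.shape)  # torch.Size([32, 10, 128]) -- one hidden vector per frame
print(h.shape)    # torch.Size([1, 32, 128]) -- final hidden state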

Key takeaways:

  • Proper data shaping for RNNs — especially feeding sequences rather than isolated frames — is crucial for capturing temporal dependencies.

  • Downsampling the audio dataset from 48 kHz to 16 kHz sacrifices some bandwidth, but it is a practical trade-off for training efficiency, and 16 kHz is a common standard for speech tasks (a resampling sketch follows this list).

  • Overlap-add reconstruction improves audio quality by removing “hard edges” between processed chunks.

  • This simple GRU denoiser is mostly effective on mild, stationary noise such as consistent hums, fan noise, or light hiss.
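
For reference, the 48 kHz to 16 kHz conversion can be done with scipy's polyphase resampler. This is a sketch assuming a mono 48 kHz WAV input; the filename is just a placeholder:

import soundfile as sf
from scipy.signal import resample_poly

audio, sr = sf.read('p232_001.wav')             # placeholder 48 kHz mono file
assert sr == 48000, "expected a 48 kHz source"
audio_16k = resample_poly(audio, up=1, down=3)  # 48000 / 3 = 16000
sf.write('p232_001_16k.wav', audio_16k, 16000)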

You can modify the model architecture, experiment with different datasets, and tune parameters to achieve better results.
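
For example, swapping the single GRU layer for a stacked one is a one-line change inside SimpleDenoiserRNN.__init__ (an untested variant, shown only to illustrate the kind of modification I mean):

# Two stacked GRU layers with inter-layer dropout instead of the single layer
self.rnn = nn.GRU(input_size, hidden_size, num_layers=2,
                  dropout=0.1, batch_first=True)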

Dataset used:
University of Edinburgh DataShare – VoiceBank-DEMAND corpus
(I used smaller clean and noisy test subsets for my experiments.)

Note on computation:

This model is designed for CPU-only execution to keep things simple and accessible for everyone. While GPU support can speed up training and inference, this project prioritizes ease of use and reproducibility without special hardware requirements.
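
If you do have a GPU and want to use it, the standard PyTorch pattern is a small change. This sketch shows the idea on a bare GRU; in the script below you would move the model once and every mini-batch the same way:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

gru = nn.GRU(128, 128, batch_first=True).to(device)
batch = torch.randn(4, 10, 128).to(device)  # move each mini-batch alongside the model
out, _ = gru(batch)
print(out.device)

Full script: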

import os
import numpy as np
import torch
import torch.nn as nn
import soundfile as sf
import matplotlib.pyplot as plt

# --- Model ---
class SimpleDenoiserRNN(nn.Module):
    def __init__(self, input_size=128, hidden_size=128):
        super().__init__()
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, input_size)

    def forward(self, x):
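        # x: (batch, seq_len, frame_size); the GRU carries hidden state across the seq_len frames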
        out, _ = self.rnn(x)
        out = self.fc(out)
        return out

# --- Sequence framing with overlap ---
def create_overlapping_sequences(audio, frame_size, sequence_length, hop_length):
    # Break audio into frames
    n_frames = len(audio) // frame_size
    audio = audio[:n_frames * frame_size]
    frames = audio.reshape(n_frames, frame_size)
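    # Slide a window of sequence_length frames with a stride of hop_length frames (frames, not samples)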

    sequences = []
    for start in range(0, n_frames - sequence_length + 1, hop_length):
        seq = frames[start:start + sequence_length]
        sequences.append(seq)
    return np.array(sequences)  # (n_sequences, seq_len, frame_size)

# --- Overlap-add reconstruction ---
def reconstruct_with_overlap(sequences, frame_size, hop_length):
    sequence_length = sequences.shape[1]
    n_sequences = sequences.shape[0]
    total_frames = hop_length * (n_sequences - 1) + sequence_length

    output = np.zeros((total_frames, frame_size))
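    # Hann taper over each full sequence; overlapping taper weights are summed and divided out below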
    window = np.hanning(sequence_length * frame_size).reshape(sequence_length, frame_size)

    frame_counts = np.zeros((total_frames, frame_size))

    for i in range(n_sequences):
        start_frame = i * hop_length
        output[start_frame:start_frame + sequence_length] += sequences[i] * window
        frame_counts[start_frame:start_frame + sequence_length] += window

    # Avoid division by zero
    frame_counts[frame_counts == 0] = 1.0
    output /= frame_counts

    return output.reshape(-1)

# --- Load dataset with overlap ---
def load_paired_sequences(clean_dir, noisy_dir, frame_size, sequence_length, hop_length):
    clean_files = sorted([f for f in os.listdir(clean_dir) if f.endswith('.wav')])
    noisy_files = sorted([f for f in os.listdir(noisy_dir) if f.endswith('.wav')])
    assert len(clean_files) == len(noisy_files), "Mismatch in file counts"

    clean_seq_list = []
    noisy_seq_list = []

    for cfile, nfile in zip(clean_files, noisy_files):
        clean_audio, sr_clean = sf.read(os.path.join(clean_dir, cfile))
        noisy_audio, sr_noisy = sf.read(os.path.join(noisy_dir, nfile))
        assert sr_clean == sr_noisy, f"Sample-rate mismatch for {cfile}/{nfile}"

        clean_seqs = create_overlapping_sequences(clean_audio, frame_size, sequence_length, hop_length)
        noisy_seqs = create_overlapping_sequences(noisy_audio, frame_size, sequence_length, hop_length)

        clean_seq_list.append(clean_seqs)
        noisy_seq_list.append(noisy_seqs)

    clean_all = np.vstack(clean_seq_list)
    noisy_all = np.vstack(noisy_seq_list)

    clean_tensor = torch.from_numpy(clean_all).float()
    noisy_tensor = torch.from_numpy(noisy_all).float()

    return noisy_tensor, clean_tensor

# --- Main ---
def main():
    clean_dir = 'cleanTrainset16k'
    noisy_dir = 'noisyTrainset16k'
    test_noisy_path = 'yourMonoAudio16k.wav'

    frame_size = 128
    sequence_length = 10
    hop_length = 5   # 50% overlap
    batch_size = 32
    epochs = 4

    print("Loading dataset...")
    noisy_data, clean_data = load_paired_sequences(clean_dir, noisy_dir, frame_size, sequence_length, hop_length)
    print(f"Total sequences loaded: {noisy_data.shape[0]} (each {sequence_length} frames, hop {hop_length})")

    model = SimpleDenoiserRNN(input_size=frame_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    # --- Training ---
    model.train()
    for epoch in range(epochs):
        total_loss = 0
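        # Plain in-order mini-batches keep the loop minimal; shuffling each epoch could improve training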
        for i in range(0, noisy_data.size(0), batch_size):
            noisy_batch = noisy_data[i:i + batch_size]
            clean_batch = clean_data[i:i + batch_size]

            optimizer.zero_grad()
            output = model(noisy_batch)
            loss = criterion(output, clean_batch)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * noisy_batch.size(0)

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / noisy_data.size(0):.6f}")

    # --- Testing ---
    print("Denoising test audio...")
    noisy_audio, sr_test = sf.read(test_noisy_path)
    noisy_sequences = create_overlapping_sequences(noisy_audio, frame_size, sequence_length, hop_length)
    frames_tensor = torch.from_numpy(noisy_sequences).float()

    model.eval()
    with torch.no_grad():
        denoised_sequences = model(frames_tensor).numpy()

    enhanced_audio = reconstruct_with_overlap(denoised_sequences, frame_size, hop_length)

    sf.write('enhancedSpeech.wav', enhanced_audio, sr_test)

    # --- Plot ---
    time = np.linspace(0, len(enhanced_audio) / sr_test, len(enhanced_audio))
    plt.figure(figsize=(12, 8))

    plt.subplot(3, 1, 1)
    plt.title('Original Noisy Audio')
    plt.plot(np.linspace(0, len(noisy_audio) / sr_test, len(noisy_audio)), noisy_audio)
    plt.xlabel('Time [s]')
    plt.ylabel('Amplitude')

    plt.subplot(3, 1, 2)
    plt.title('Enhanced Audio')
    plt.plot(time, enhanced_audio)
    plt.xlabel('Time [s]')
    plt.ylabel('Amplitude')

    plt.subplot(3, 1, 3)
    plt.title('Estimated Noise (Noisy - Enhanced)')
    plt.plot(time, noisy_audio[:len(enhanced_audio)] - enhanced_audio)
    plt.xlabel('Time [s]')
    plt.ylabel('Amplitude')

    plt.tight_layout()
    plt.show()

if __name__ == '__main__':
    main()
References:

[1] Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling

[2] A Review of Deep Learning Techniques for Speech Processing

[3] A Friendly Introduction to Recurrent Neural Networks

[4] University of Edinburgh DataShare – VoiceBank-DEMAND corpus