Denoising with Rectified Flow Models
This project trains a 1D convolutional neural network with time embeddings to remove noise from audio, inspired by Rectified Flow Models (RFM).
What is a Rectified Flow Model?
An RFM learns a velocity field that gradually transforms noisy data into clean data over time. Instead of predicting the clean audio directly, the model predicts the direction and magnitude of change needed at each time step to move from noise toward the target signal.
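For intuition, here is a minimal sketch of that construction (illustrative only; the training loss in the script below instead perturbs the noisy input slightly and damps the target by a factor of 0.8). It draws a point on the straight path between the two signals and uses that path's constant velocity as the regression target:

import torch

def rectified_flow_pair(x_noisy, x_clean, t):
    # A point on the straight path from x_noisy (t=0) to x_clean (t=1)
    xt = (1 - t) * x_noisy + t * x_clean
    # The path's velocity is constant: d(xt)/dt = x_clean - x_noisy
    target_velocity = x_clean - x_noisy
    return xt, target_velocity

The model is trained so that model(xt, t) matches target_velocity; integrating the learned velocity field from t = 0 to t = 1 then carries a noisy input toward its clean counterpart.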
How it works:
Conditional training: The model is trained on (noisy, clean) audio pairs, learning to map noisy inputs toward their clean counterparts.
Iterative denoising: During inference, the model refines the audio step by step using predicted velocities, with a cosine-scheduled time embedding for smooth, stable refinement (see the sketch after this list).
Blending & velocity scaling: To preserve subtle speech, the final output blends the denoised signal with a fraction of the original noisy audio, and predicted velocities can be scaled for gentler processing.
Chunked processing: Audio is processed in short, fixed-length chunks (one second by default), avoiding artifacts and reducing memory usage.
Visualization: The demo shows the original noisy audio, the denoised output, and the difference (what noise was removed).
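Taken together, the inference loop reduces to a few lines per chunk. This sketch (the function name denoise_chunk_sketch is hypothetical) mirrors rectified_flow_denoise in the script below, assuming a chunk tensor of shape [1, 1, chunk_size]:

import math
import torch

def denoise_chunk_sketch(model, chunk, steps=50, velocity_scale=0.9, blend_ratio=0.1):
    x = chunk.clone()  # keep the noisy chunk for final blending
    with torch.no_grad():
        for i in range(steps):
            # Cosine schedule: t sweeps smoothly from 0 toward 1
            t = torch.full((x.size(0),), 0.5 * (1 - math.cos(math.pi * i / steps)),
                           device=x.device)
            # Euler step along the (scaled) predicted velocity
            x = x + (1 / steps) * velocity_scale * model(x, t)
    # Mix a fraction of the original noisy chunk back in to preserve subtle speech
    return (1 - blend_ratio) * x + blend_ratio * chunk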
This approach demonstrates how RFMs can be applied to real-world audio denoising, balancing noise removal against speech preservation. Treat it as a simple demonstration structure rather than a production system.
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

# --- Sinusoidal time embedding ---
def sinusoidal_embedding(t, dim=32):
    device = t.device
    half_dim = dim // 2
    emb = torch.exp(torch.arange(half_dim, device=device) * -(math.log(10000) / half_dim))
    emb = t[:, None] * emb[None, :]
    emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
    return emb

# --- 1D ConvNet with time embedding ---
class U1DNet(nn.Module):
    def __init__(self, time_emb_dim=32):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )
        self.conv1 = nn.Conv1d(1, 32, 9, padding=4)
        self.conv2 = nn.Conv1d(32, 32, 9, padding=4)
        self.conv3 = nn.Conv1d(32, 1, 9, padding=4)

    def forward(self, x, t):
        t_emb = sinusoidal_embedding(t, dim=32)
        t_emb = self.time_mlp(t_emb).unsqueeze(-1)
        x = self.conv1(x)
        x = x + t_emb  # inject time information at every sample position
        x = torch.relu(x)
        x = torch.relu(self.conv2(x))
        x = self.conv3(x)
        return x

# --- Dataset for paired noisy-clean audio ---
class PairedAudioDataset(Dataset):
    def __init__(self, clean_folder, noisy_folder, sr=16000, duration=None):
        self.clean_files = sorted(os.listdir(clean_folder))
        self.noisy_files = sorted(os.listdir(noisy_folder))
        assert len(self.clean_files) == len(self.noisy_files), "Mismatch in number of files!"
        self.clean_folder = clean_folder
        self.noisy_folder = noisy_folder
        self.sr = sr
        self.duration = duration

    def __len__(self):
        return len(self.clean_files)

    def __getitem__(self, idx):
        clean_path = os.path.join(self.clean_folder, self.clean_files[idx])
        noisy_path = os.path.join(self.noisy_folder, self.noisy_files[idx])
        x_clean, _ = torchaudio.load(clean_path)
        x_noisy, _ = torchaudio.load(noisy_path)
        x_clean = x_clean.mean(dim=0, keepdim=True)  # downmix to mono
        x_noisy = x_noisy.mean(dim=0, keepdim=True)
        if self.duration:
            num_samples = int(self.sr * self.duration)
            x_clean = x_clean[:, :num_samples]
            x_noisy = x_noisy[:, :num_samples]
        x_clean = x_clean / x_clean.abs().max()  # peak-normalize
        x_noisy = x_noisy / x_noisy.abs().max()
        return x_noisy, x_clean

# --- Rectified Flow loss for conditional denoising ---
def rectified_flow_loss(model, x_noisy, x_clean):
    B = x_clean.size(0)
    t = torch.rand(B, device=x_clean.device).view(-1)
    # Optional small perturbation for smoothness
    noise = torch.randn_like(x_clean) * 0.01
    xt = x_noisy + t.view(-1, 1, 1) * noise
    target_velocity = (x_clean - xt) * 0.8  # damped target for gentler updates
    pred_velocity = model(xt, t)
    return ((pred_velocity - target_velocity) ** 2).mean()

# --- Parameters ---
device = "cuda" if torch.cuda.is_available() else "cpu"
sr = 16000
train_duration = 1  # seconds per slice
batch_size = 8
num_epochs = 50
learning_rate = 1e-4

# --- Dataset & DataLoader ---
dataset = PairedAudioDataset(
    clean_folder='cleanTrainset16k',
    noisy_folder='noisyTrainset16k',
    sr=sr,
    duration=train_duration
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Model & Optimizer ---
model = U1DNet(time_emb_dim=32).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# --- Training loop ---
model_path = 'rfm_conditional_model.pth'  # previously saved model path
if os.path.exists(model_path):
    # Load existing model (map_location handles CPU-only machines)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    print(f"Loaded existing model from {model_path}, skipping training.")
else:
    for epoch in range(num_epochs):
        for x_noisy, x_clean in dataloader:
            x_noisy, x_clean = x_noisy.to(device), x_clean.to(device)
            loss = rectified_flow_loss(model, x_noisy, x_clean)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{num_epochs}: loss={loss.item():.6f}")
    torch.save(model.state_dict(), model_path)
    print(f"Saved conditional rectified flow model to {model_path}.")

# --- Denoise / generate new audio with blending and velocity scaling ---
def rectified_flow_denoise(model, x_noisy, steps=50, chunk_duration=1,
                           blend_ratio=0.1, velocity_scale=1.0):
    """
    model: trained RFM model
    x_noisy: input noisy audio (Tensor [1, num_samples] or [num_samples])
    steps: number of iterative steps
    chunk_duration: seconds per chunk
    blend_ratio: fraction of original noisy audio to mix at the end
    velocity_scale: scale factor for predicted velocity (less than 1 for gentler denoising)
    """
    model.eval()
    if x_noisy.dim() == 1:
        x_noisy = x_noisy.unsqueeze(0)
    x_noisy = x_noisy.mean(dim=0, keepdim=True)
    x_noisy = x_noisy / x_noisy.abs().max()
    num_samples = x_noisy.size(1)
    chunk_size = int(sr * chunk_duration)
    output = []
    with torch.no_grad():
        for start in range(0, num_samples, chunk_size):
            end = min(start + chunk_size, num_samples)
            chunk = x_noisy[:, start:end]
            if chunk.size(1) < chunk_size:
                pad_len = chunk_size - chunk.size(1)
                chunk = torch.nn.functional.pad(chunk, (0, pad_len))
            chunk = chunk.unsqueeze(0).to(device)
            x = chunk.clone()  # start from noisy input
            for i in range(steps):
                # Smooth t schedule (cosine)
                t = torch.full((1,), 0.5 * (1 - math.cos(math.pi * i / steps)), device=device)
                velocity = model(x, t) * velocity_scale
                x = x + (1 / steps) * velocity
            x = x[:, :, :end - start]  # drop padding
            # Blend with original noisy input to preserve subtle speech
            x = (1 - blend_ratio) * x.cpu() + blend_ratio * x_noisy[:, start:end]
            output.append(x)
    return torch.cat(output, dim=2).squeeze()

# --- Load test noisy audio ---
test_path = 'testMonoAudio.wav'
x_noisy_test, _ = torchaudio.load(test_path)
# Downmix and peak-normalize so the difference below is on the same scale
# as the (internally normalized) denoised output
x_noisy_test = x_noisy_test.mean(dim=0, keepdim=True)
x_noisy_test = x_noisy_test / x_noisy_test.abs().max()

# --- Denoise ---
denoised_waveform = rectified_flow_denoise(model, x_noisy_test, steps=50,
                                           chunk_duration=train_duration,
                                           blend_ratio=0.1, velocity_scale=0.9)

# --- Save denoised audio ---
torchaudio.save("rfm_denoised_output.wav", denoised_waveform.unsqueeze(0), sr)

# --- Save difference audio ---
diff_waveform = x_noisy_test.squeeze()[:denoised_waveform.size(0)] - denoised_waveform
torchaudio.save("rfm_difference.wav", diff_waveform.unsqueeze(0), sr)

# --- Plot ---
plt.figure(figsize=(10, 9))

# Original Noisy Audio
plt.subplot(3, 1, 1)
plt.plot(x_noisy_test.squeeze().numpy(), color='r')
plt.title("Original Noisy Audio")
plt.xlabel("Sample Index")
plt.ylabel("Amplitude")

# Conditional RFM Denoised Output
plt.subplot(3, 1, 2)
plt.plot(denoised_waveform.numpy(), color='b')
plt.title("Conditional RFM Denoised Output")
plt.xlabel("Sample Index")
plt.ylabel("Amplitude")

# Difference (Noisy - Denoised)
plt.subplot(3, 1, 3)
diff = x_noisy_test.squeeze().numpy()[:denoised_waveform.size(0)] - denoised_waveform.numpy()
plt.plot(diff, color='g')
plt.title("Difference: Noisy - Denoised")
plt.xlabel("Sample Index")
plt.ylabel("Amplitude")

plt.tight_layout()
plt.show()
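To run the demo: place the same number of clean and noisy 16 kHz WAV files in cleanTrainset16k and noisyTrainset16k (pairs are matched by sorted filename order), put a mono 16 kHz test file at testMonoAudio.wav, and run the script. It trains the model (or reloads rfm_conditional_model.pth if one exists), writes rfm_denoised_output.wav and rfm_difference.wav, and shows the three-panel plot.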