Neural Audio Codec
This is a basic neural audio codec built from a convolutional encoder, a quantizer, and a convolutional decoder. The encoder downsamples the raw waveform by a factor of 8 in time, turning it into a sequence of latent vectors with latent_dim channels each. A simple scalar quantizer then rounds each latent value to one of a fixed number of discrete levels, simulating the bitrate limit of a real codec. The decoder upsamples the quantized latent back into a waveform at the original sample rate.
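To make the downsampling concrete, the sketch below rebuilds the same encoder and decoder layers and traces tensor shapes for two input lengths. The helper trace_shapes is hypothetical and exists only for illustration. Each stride-2 convolution roughly halves the time axis, and each transposed convolution exactly doubles it, so the decoder always returns a multiple of 8 samples.

```python
import torch
import torch.nn as nn

def build_encoder(latent_dim):
    # Same layout as the codec's encoder: three stride-2 convolutions (8x downsampling).
    return nn.Sequential(
        nn.Conv1d(1, 16, 3, stride=2, padding=1), nn.ReLU(),
        nn.Conv1d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
        nn.Conv1d(32, latent_dim, 3, stride=2, padding=1),
    )

def build_decoder(latent_dim):
    # Same layout as the codec's decoder: three stride-2 transposed convolutions (8x upsampling).
    return nn.Sequential(
        nn.ConvTranspose1d(latent_dim, 32, 4, stride=2, padding=1), nn.ReLU(),
        nn.ConvTranspose1d(32, 16, 4, stride=2, padding=1), nn.ReLU(),
        nn.ConvTranspose1d(16, 1, 4, stride=2, padding=1), nn.Tanh(),
    )

def trace_shapes(num_samples, latent_dim=64):
    """Hypothetical helper: return (latent shape, output shape) for a (1, 1, num_samples) input."""
    encoder, decoder = build_encoder(latent_dim), build_decoder(latent_dim)
    with torch.no_grad():
        x = torch.zeros(1, 1, num_samples)  # (batch, channels, samples)
        z = encoder(x)                      # (batch, latent_dim, ceil(samples / 8))
        y = decoder(z)                      # (batch, 1, 8 * ceil(samples / 8))
    return tuple(z.shape), tuple(y.shape)

for n in (1000, 1001):
    z_shape, y_shape = trace_shapes(n)
    print(f"input {n:5d} -> latent {z_shape} -> output {y_shape}")
```

A 1000-sample input comes back as 1000 samples, but a 1001-sample input comes back as 1008, so the reconstruction has to be cropped to the input length (or the input padded to a multiple of 8) before the two can be compared with an element-wise loss.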
After building the encoder and decoder networks, a short training loop (a dummy run of 100 steps on a single clip) minimizes the mean squared error between the input and the reconstructed output. In practice, neural audio codecs are used to stream audio at low bitrates while preserving perceptual quality.
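Whether this toy codec actually lowers the bitrate depends on latent_dim and the number of quantizer levels. The sketch below does the arithmetic under stated assumptions: one latent frame per 8 input samples (the encoder above), and every latent value sent with a fixed log2(levels) bits and no entropy coding. The functions codec_bitrate and pcm_bitrate are hypothetical helpers, not part of any library.

```python
import math

def codec_bitrate(sample_rate, latent_dim, levels, downsample=8):
    """Bits per second for the quantized latent.

    Assumes one latent frame per `downsample` input samples and a fixed
    log2(levels) bits per latent value (no entropy coding).
    """
    frames_per_second = sample_rate / downsample
    bits_per_value = math.log2(levels)
    return frames_per_second * latent_dim * bits_per_value

def pcm_bitrate(sample_rate, bit_depth=16, channels=1):
    """Bitrate of uncompressed PCM audio."""
    return sample_rate * bit_depth * channels

sr = 16_000
for latent_dim in (64, 4):
    kbps = codec_bitrate(sr, latent_dim, levels=256) / 1000
    print(f"latent_dim={latent_dim:2d}: {kbps:7.1f} kbps (16-bit PCM: {pcm_bitrate(sr) / 1000:.1f} kbps)")
```

At 16 kHz with 256 levels, latent_dim=64 needs 1024 kbps, four times more than 16-bit mono PCM (256 kbps); dropping to latent_dim=4 brings it down to 64 kbps.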
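The script prints the training loss every 10 steps, plots the input and output waveforms on top of each other, and writes the reconstructed audio to Reconstructed.wav. You can experiment with the latent dimension and the encoder and decoder parameters; keep in mind that with 8x downsampling, the latent holds fewer values than the input only when latent_dim is below 8. Note that the quantizer is a plain rounding function, and rounding has zero gradient, so training through it needs a workaround such as a straight-through estimator. It can be replaced with a more advanced scheme, such as the residual scalar-vector quantization of [2] or the residual vector quantization used by EnCodec [3], to improve results.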
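The sketch below shows both quantizer issues in isolation. First, with step = 2 / levels the rounding produces levels + 1 distinct values (indices 0 through levels), so this version uses step = 2 / (levels - 1) to get exactly levels values. Second, torch.round has zero gradient, so a straight-through estimator (the hypothetical quantize_ste) sends the quantized values forward but copies the gradient back unchanged. This is one common workaround, not necessarily what the referenced codecs use.

```python
import torch

def quantize(tensor, levels=256):
    """Uniform scalar quantizer with exactly `levels` values in [-1, 1]."""
    tensor = torch.clamp(tensor, -1.0, 1.0)
    step = 2.0 / (levels - 1)  # spacing between adjacent levels
    return torch.round((tensor + 1.0) / step) * step - 1.0

def quantize_ste(tensor, levels=256):
    """Straight-through estimator: quantized values forward, identity gradient backward."""
    return tensor + (quantize(tensor, levels) - tensor).detach()

z = torch.linspace(-0.9, 0.9, 5, requires_grad=True)

# Plain rounding: the gradient of torch.round is zero, so nothing reaches z.
quantize(z, levels=4).sum().backward()
print("round grad:", z.grad)

z.grad = None
# STE: same forward values, but the gradient passes straight through to z.
quantize_ste(z, levels=4).sum().backward()
print("STE grad:  ", z.grad)
```

Without the straight-through trick, only the decoder would learn; the encoder's weights would receive zero gradient at every step.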
```python
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt


# --- Encoder ---
def build_encoder(latent_dim):
    # Three stride-2 convolutions: downsample the waveform 8x in time.
    return nn.Sequential(
        nn.Conv1d(1, 16, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv1d(16, 32, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv1d(32, latent_dim, 3, stride=2, padding=1),
    )


# --- Decoder ---
def build_decoder(latent_dim):
    # Three stride-2 transposed convolutions: upsample 8x back to the audio rate.
    return nn.Sequential(
        nn.ConvTranspose1d(latent_dim, 32, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose1d(32, 16, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose1d(16, 1, 4, stride=2, padding=1),
        nn.Tanh(),  # output in [-1, 1], matching the normalized input
    )


# --- Quantizer (Basic Rounding Quantizer) ---
def quantize(tensor, levels=256):
    # Uniform quantizer with exactly `levels` values in [-1, 1].
    tensor = torch.clamp(tensor, -1.0, 1.0)
    step = 2.0 / (levels - 1)
    return torch.round((tensor + 1.0) / step) * step - 1.0


# --- Forward through Codec ---
def forward_codec(x, encoder, decoder):
    z = encoder(x)
    # Straight-through estimator: torch.round has zero gradient, so send the
    # quantized values forward but copy the gradient straight back to z.
    z_q = z + (quantize(z) - z).detach()
    x_hat = decoder(z_q)
    # The decoder returns a multiple of 8 samples; crop to the input length.
    return x_hat[..., : x.shape[-1]]


# --- Training step ---
def train_step(x, encoder, decoder, optimizer):
    encoder.train()
    decoder.train()
    optimizer.zero_grad()
    x_hat = forward_codec(x, encoder, decoder)
    loss = nn.MSELoss()(x_hat, x)
    loss.backward()
    optimizer.step()
    return loss.item()


# --- Main ---
if __name__ == "__main__":
    # Load input wav
    input_path = "Yourfile.wav"
    waveform, sample_rate = torchaudio.load(input_path)  # (channels, samples)
    waveform = waveform.mean(dim=0, keepdim=True)  # mono: (1, samples)
    waveform = waveform[:, : sample_rate * 3]  # keep at most 3 seconds
    waveform = waveform.unsqueeze(0)  # add batch dim: (1, 1, samples)

    # Normalize peak amplitude to 1 (the clamp guards against a silent file)
    waveform = waveform / waveform.abs().max().clamp(min=1e-8)

    latent_dim = 64
    encoder = build_encoder(latent_dim)
    decoder = build_decoder(latent_dim)
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3
    )

    # Dummy training loop: fit a single clip for 100 steps
    for step in range(1, 101):
        loss = train_step(waveform, encoder, decoder, optimizer)
        if step % 10 == 0:
            print(f"Step {step:03d}, Loss: {loss:.6f}")

    # Inference
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        reconstructed = forward_codec(waveform, encoder, decoder)

    # Save reconstructed audio as (channels, samples)
    torchaudio.save("Reconstructed.wav", reconstructed.squeeze(0).cpu(), sample_rate)

    # Plot waveforms
    x = waveform.squeeze().cpu().numpy()
    x_hat = reconstructed.squeeze().cpu().numpy()
    plt.figure(figsize=(12, 4))
    plt.plot(x, label="Original")
    plt.plot(x_hat, label="Reconstructed", alpha=0.7)
    plt.legend()
    plt.title("Original vs Reconstructed Waveform")
    plt.show()
```

[1] RAVE: A variational autoencoder for fast and high-quality neural audio synthesis
[2] A Streamable Neural Audio Codec with Residual Scalar-Vector Quantization for Real-Time Communication
[3] Facebook Research – Encodec