Neural Audio Codec

This is a basic neural audio codec implementation using a convolutional encoder and decoder, along with a quantizer. The encoder compresses raw audio into a shorter sequence of latent vectors. A simple scalar quantizer then rounds each latent value to one of a fixed number of discrete levels, simulating the bitrate constraint of a real codec. The decoder reconstructs the waveform from the quantized latent representation.
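
To get a feel for the bitrate this setup implies, here is a quick back-of-envelope check using the parameters from the script below (latent_dim = 64, 8x downsampling from three stride-2 convolutions, 256 quantization levels); the numbers are illustrative, not a tuned design:

# Each group of 8 input samples becomes latent_dim quantized values,
# and 256 levels cost log2(256) = 8 bits per value
latent_dim = 64
downsample = 8
bits_per_value = 8
bits_per_sample = latent_dim / downsample * bits_per_value
print(bits_per_sample)  # 64.0 -- versus 16 bits/sample for raw PCM

With these defaults the latent actually costs more bits than 16-bit PCM, which is exactly why shrinking the latent dimension or the number of quantization levels is the interesting experiment here.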

After building the convolutional encoder and decoder networks, a short dummy training loop (100 steps) minimizes the mean-squared reconstruction error between the input and the output. Neural audio codecs are especially useful for streaming audio at low bitrates while preserving as much perceptual quality as possible.

You’ll see the training progress printed along with a plot comparing the input and output waveforms. Finally, the reconstructed audio is written to disk. You can experiment with the latent dimension and the encoder/decoder parameters. Note that the quantizer is a basic rounding function; it can be replaced with a more advanced version, such as a learned vector quantizer, to improve results (see the sketch after the script).

import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt

# --- Encoder ---
def build_encoder(latent_dim):
    return nn.Sequential(
        nn.Conv1d(1, 16, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv1d(16, 32, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv1d(32, latent_dim, 3, stride=2, padding=1)
    )
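
# Each stride-2 convolution halves the time axis, so the encoder
# downsamples by a factor of 8: (B, 1, T) -> (B, latent_dim, T/8).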

# --- Decoder ---
def build_decoder(latent_dim):
    return nn.Sequential(
        nn.ConvTranspose1d(latent_dim, 32, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose1d(32, 16, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose1d(16, 1, 4, stride=2, padding=1),
        nn.Tanh()
    )
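
# The transposed convolutions mirror the encoder (8x upsampling overall);
# the final Tanh keeps the output waveform within [-1, 1].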

# --- Quantizer (Basic Rounding Quantizer) ---
def quantize(tensor, levels=256):
    # Clamp to [-1, 1], then snap each value to one of `levels` uniform steps.
    # torch.round has zero gradient almost everywhere, so forward_codec
    # applies it with a straight-through estimator.
    tensor = torch.clamp(tensor, -1.0, 1.0)
    step = 2.0 / levels
    return torch.round((tensor + 1.0) / step) * step - 1.0
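
# Example: with 256 levels the step is 2/256 = 0.0078125, so
# quantize(torch.tensor([0.123])) returns tensor([0.1250]).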

# --- Forward through Codec ---
def forward_codec(x, encoder, decoder):
    z = encoder(x)
    # Straight-through estimator: quantized values in the forward pass,
    # identity gradient in the backward pass, so the encoder still learns
    z_q = z + (quantize(z) - z).detach()
    return decoder(z_q)

# --- Training step ---
def train_step(x, encoder, decoder, optimizer):
    encoder.train()
    decoder.train()
    optimizer.zero_grad()
    x_hat = forward_codec(x, encoder, decoder)
    loss = nn.MSELoss()(x_hat, x)
    loss.backward()
    optimizer.step()
    return loss.item()

# --- Main ---
if __name__ == "__main__":

    # Load input wav
    input_path = "Yourfile.wav"
    waveform, sample_rate = torchaudio.load(input_path)
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
    waveform = waveform[:, :sample_rate * 3]  # keep at most 3 seconds
    # Trim to a multiple of 8 samples so the decoder output matches the
    # input length (the codec downsamples and upsamples by a factor of 8)
    waveform = waveform[:, :waveform.shape[1] - waveform.shape[1] % 8]
    waveform = waveform.unsqueeze(0)  # add batch dimension: (1, 1, T)

    # Normalize to [-1, 1] (guard against an all-silent file)
    waveform = waveform / waveform.abs().max().clamp(min=1e-8)

    latent_dim = 64
    encoder = build_encoder(latent_dim)
    decoder = build_decoder(latent_dim)
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

    # Dummy training loop: overfit this single clip
    for step in range(1, 101):
        loss = train_step(waveform, encoder, decoder, optimizer)
        if step % 10 == 0:
            print(f"Step {step:03d}, Loss: {loss:.6f}")

    # Inference
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        reconstructed = forward_codec(waveform, encoder, decoder)

    # Save reconstructed audio
    torchaudio.save("Reconstructed.wav", reconstructed.squeeze(0).cpu(), sample_rate)

    # Plot waveforms
    x = waveform.squeeze().cpu().numpy()
    x_hat = reconstructed.squeeze().cpu().numpy()
    plt.figure(figsize=(12, 4))
    plt.plot(x, label="Original")
    plt.plot(x_hat, label="Reconstructed", alpha=0.7)
    plt.legend()
    plt.title("Original vs Reconstructed Waveform")
    plt.show()
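
As noted above, the rounding quantizer is the weakest link here. Below is a minimal sketch of a learned vector quantizer (VQ-VAE style) that could replace quantize() inside forward_codec. The codebook size and initialization are illustrative assumptions, and a complete implementation would also add codebook and commitment loss terms.

class VectorQuantizer(nn.Module):
    def __init__(self, num_codes=256, dim=64):
        super().__init__()
        self.codebook = nn.Embedding(num_codes, dim)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)

    def forward(self, z):
        # z: (B, dim, T) -> one row per latent frame
        B, C, T = z.shape
        flat = z.permute(0, 2, 1).reshape(-1, C)
        # Pick the nearest codebook entry for each frame (Euclidean distance)
        codes = torch.cdist(flat, self.codebook.weight).argmin(dim=1)
        z_q = self.codebook(codes).reshape(B, T, C).permute(0, 2, 1)
        # Straight-through estimator, as in forward_codec above
        return z + (z_q - z).detach()

Swapping it in amounts to constructing vq = VectorQuantizer(dim=latent_dim) alongside the networks, adding its parameters to the optimizer, and replacing the quantize call in forward_codec with vq(z).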