Neural Audio Codec
This is a basic neural audio codec implementation using a convolutional encoder and decoder, along with a quantizer. The encoder compresses raw audio into a compact latent representation. A simple scalar quantizer then rounds the latent values to discrete levels, simulating bitrate compression. The decoder reconstructs the waveform from the quantized latents.
After building the convolutional encoder and decoder networks, a short dummy training loop (about 100 steps on a single clip) minimizes the reconstruction loss between input and output. Neural audio codecs are especially useful for streaming audio at low bitrates while preserving high quality.
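As a rough sanity check on the bitrate, the nominal rate of a scalar-quantized codec like this one follows from the temporal downsampling factor, the number of latent channels, and the bits per quantized value. The snippet below is a back-of-the-envelope sketch using this script's defaults (the 44.1 kHz input rate is an assumption); note that with latent_dim = 64 the nominal rate comes out above 16-bit PCM (about 706 kbps), so real compression here comes from shrinking latent_dim or levels.

import math

sample_rate = 44100   # assumed input rate
downsampling = 8      # three stride-2 conv layers
latent_dim = 64       # latent channels
levels = 256          # quantization levels -> log2(256) = 8 bits each

latent_frames_per_sec = sample_rate / downsampling
bitrate_bps = latent_frames_per_sec * latent_dim * math.log2(levels)
print(f"Nominal bitrate: {bitrate_bps / 1000:.0f} kbps")  # ~2822 kbps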
The script prints the training loss and shows a plot comparing the input and output waveforms. Finally, the reconstructed audio file is written to disk. You can experiment with the latent dimension and the encoder-decoder parameters. Note that the quantizer is a basic rounding function; it can be replaced with a more advanced version to improve results (a mu-law variant is sketched below the quantizer).
If you’d like to support my audio content, you can find me here:
https://buymeacoffee.com/hakanyurdakul
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
# --- Encoder ---
def build_encoder(latent_dim):
    # Three stride-2 convolutions downsample the waveform 8x in time
    # while expanding to latent_dim channels.
    return nn.Sequential(
        nn.Conv1d(1, 16, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv1d(16, 32, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv1d(32, latent_dim, 3, stride=2, padding=1)
    )
# --- Decoder ---
def build_decoder(latent_dim):
    # Mirror of the encoder: three stride-2 transposed convolutions
    # upsample 8x in time; Tanh keeps the output waveform in [-1, 1].
    return nn.Sequential(
        nn.ConvTranspose1d(latent_dim, 32, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose1d(32, 16, 4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose1d(16, 1, 4, stride=2, padding=1),
        nn.Tanh()
    )
# --- Quantizer (Basic Rounding Quantizer) ---
def quantize(tensor, levels=256):
    # Uniform scalar quantization: clamp to [-1, 1], then snap each value
    # to one of `levels` evenly spaced steps.
    tensor = torch.clamp(tensor, -1.0, 1.0)
    step = 2.0 / levels
    return torch.round((tensor + 1.0) / step) * step - 1.0
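# --- Optional: mu-law quantizer (hedged sketch, unused by default) ---
# The rounding quantizer above can be swapped for a companding quantizer
# such as mu-law, which spends more of its levels on quiet samples.
# `quantize_mu_law` is a hypothetical helper added for illustration;
# call it in place of `quantize` inside forward_codec to try it.
def quantize_mu_law(tensor, levels=256, mu=255.0):
    tensor = torch.clamp(tensor, -1.0, 1.0)
    log1p_mu = torch.log1p(torch.tensor(mu))
    # Compress: map amplitudes into the mu-law domain (still in [-1, 1])
    companded = torch.sign(tensor) * torch.log1p(mu * tensor.abs()) / log1p_mu
    # Quantize uniformly in the companded domain
    step = 2.0 / levels
    companded = torch.round((companded + 1.0) / step) * step - 1.0
    # Expand: invert the companding back to the linear domain
    return torch.sign(companded) * torch.expm1(companded.abs() * log1p_mu) / mu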
# --- Forward through Codec ---
def forward_codec(x, encoder, decoder):
    z = encoder(x)
    # Straight-through estimator: use the quantized values in the forward
    # pass, but let gradients bypass the zero-gradient round() so the
    # encoder can actually learn.
    z_q = z + (quantize(z) - z).detach()
    return decoder(z_q)
# --- Training step ---
def train_step(x, encoder, decoder, optimizer):
    encoder.train()
    decoder.train()
    optimizer.zero_grad()
    x_hat = forward_codec(x, encoder, decoder)
    loss = nn.MSELoss()(x_hat, x)
    loss.backward()
    optimizer.step()
    return loss.item()
# --- Main ---
if __name__ == "__main__":
    # Load input wav
    input_path = "Yourfile.wav"
    waveform, sample_rate = torchaudio.load(input_path)
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
    waveform = waveform[:, :sample_rate * 3]       # keep at most 3 seconds
    # Trim to a multiple of 8 samples so the 8x down/upsampling path
    # reproduces the exact input length (otherwise the loss shapes mismatch)
    waveform = waveform[:, :(waveform.shape[1] // 8) * 8]
    waveform = waveform.unsqueeze(0)               # batch size 1
    # Normalize to peak amplitude 1
    waveform = waveform / waveform.abs().max()

    latent_dim = 64
    encoder = build_encoder(latent_dim)
    decoder = build_decoder(latent_dim)
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
    # Dummy training loop
    for step in range(1, 101):
        loss = train_step(waveform, encoder, decoder, optimizer)
        if step % 10 == 0:
            print(f"Step {step:03d}, Loss: {loss:.6f}")
    # Inference
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        reconstructed = forward_codec(waveform, encoder, decoder)

    # Save reconstructed audio
    torchaudio.save("Reconstructed.wav", reconstructed.squeeze(0).cpu(), sample_rate)
    # Plot waveforms
    x = waveform.squeeze().cpu().numpy()
    x_hat = reconstructed.squeeze().cpu().numpy()
    plt.figure(figsize=(12, 4))
    plt.plot(x, label="Original")
    plt.plot(x_hat, label="Reconstructed", alpha=0.7)
    plt.legend()
    plt.title("Original vs Reconstructed Waveform")
    plt.show()
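Beyond eyeballing the plot, reconstruction quality can be summarized with a signal-to-noise ratio. This is a minimal sketch that assumes the x and x_hat arrays produced by the script above (equal length thanks to the multiple-of-8 trim); it is not part of the original script.

import numpy as np

def snr_db(reference, estimate):
    # SNR in dB between the original and the reconstruction; higher is better.
    noise = reference - estimate
    return 10.0 * np.log10(np.sum(reference ** 2) / (np.sum(noise ** 2) + 1e-12))

print(f"SNR: {snr_db(x, x_hat):.2f} dB")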