Neural Audio Fingerprinting

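This is a basic implementation of neural audio fingerprinting using PyTorch and torchaudio. It serves as a starting point for tasks such as music identification, audio search, and deep audio retrieval. Note that the network below is randomly initialized: its embeddings only become meaningful fingerprints after the model has been trained, for example with a contrastive objective [2]. The CNN architecture, sample rate, and input duration can be adjusted for experimentation or improved performance. Fingerprints can be compared with cosine similarity or Euclidean distance. Because the fingerprints are L2-normalized, both give the same ranking: for unit vectors a and b, ‖a − b‖² = 2 − 2·(a·b). The implementation covers: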

  • Audio loading and pre-processing (resampling, fixed-length cropping or padding)

  • Mel spectrogram extraction as input features

  • A simple convolutional neural network (CNN) to generate compact embeddings

  • L2 normalization for stable, consistent fingerprint representation

import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as t
import matplotlib.pyplot as plt

# Every clip is resampled to this rate (samples per second).
SAMPLE_RATE = 16_000
# Clip length, in seconds, that every input is cropped or zero-padded to.
DURATION = 5
# The same fixed clip length expressed in samples.
NUM_SAMPLES = DURATION * SAMPLE_RATE

class AudioFingerprintNet(nn.Module):
    """Small CNN that maps a mel spectrogram to an L2-normalized embedding.

    The weights are randomly initialized here. The network must be trained
    (for example with a contrastive objective) before its embeddings are
    useful as audio fingerprints.

    Args:
        embedding_dim: Length of the output fingerprint vector.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            # Collapse the remaining (freq, time) grid to 1x1 so the feature
            # size fed to the linear layer does not depend on the input size.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(64, embedding_dim)

    def forward(self, x):
        """Embed a batch of single-channel spectrograms.

        Args:
            x: Tensor of shape (batch, 1, n_mels, frames).

        Returns:
            Tensor of shape (batch, embedding_dim). Each non-zero row has unit
            L2 norm. A row that is exactly zero stays zero instead of
            becoming NaN.
        """
        x = self.conv(x).flatten(1)
        x = self.fc(x)
        # normalize() divides by max(norm, eps), avoiding 0/0 = NaN for a
        # degenerate all-zero embedding.
        return nn.functional.normalize(x, p=2, dim=1)

def load_audio(input_path):
    """Load a WAV file as a mono, fixed-length waveform at SAMPLE_RATE.

    Processing steps:

    1. Average multi-channel audio down to one channel, because the first
       convolution of AudioFingerprintNet accepts exactly one input channel.
    2. Resample to SAMPLE_RATE if the file uses a different rate.
    3. Truncate the signal, or zero-pad it at the end, to NUM_SAMPLES samples.

    Args:
        input_path: Path to a WAV file.

    Returns:
        Tensor of shape (1, NUM_SAMPLES).
    """
    waveform, sr = torchaudio.load(input_path, format="wav")
    # torchaudio returns (channels, time). Without downmixing, a stereo file
    # would reach the model as a 2-channel input and crash the first Conv2d.
    # Downmixing first also means only one channel has to be resampled.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != SAMPLE_RATE:
        waveform = t.Resample(orig_freq=sr, new_freq=SAMPLE_RATE)(waveform)
    if waveform.size(1) > NUM_SAMPLES:
        waveform = waveform[:, :NUM_SAMPLES]
    else:
        # Right-pad with zeros; a no-op when the length is already exact.
        waveform = nn.functional.pad(waveform, (0, NUM_SAMPLES - waveform.size(1)))
    return waveform

def get_mel(waveform):
    """Convert a waveform into a dB-scaled mel spectrogram with a batch axis.

    Args:
        waveform: Waveform tensor as produced by ``load_audio``.

    Returns:
        The 64-band mel spectrogram in decibels, with one extra leading
        dimension inserted as the batch axis.
    """
    to_mel = t.MelSpectrogram(sample_rate=SAMPLE_RATE, n_mels=64)
    to_db = t.AmplitudeToDB()
    spectrogram_db = to_db(to_mel(waveform))
    # Prepend the batch axis expected by the model.
    return spectrogram_db[None, ...]

def generate_fingerprint(input_path, model=None, seed=0):
    """Compute the fingerprint vector for one audio file.

    Args:
        input_path: Path to a WAV file (loaded with ``load_audio``).
        model: Network used to embed the spectrogram. Pass the same (ideally
            trained) model for every file whose fingerprints you intend to
            compare. When omitted, an untrained ``AudioFingerprintNet`` is
            built with weights determined by ``seed``.
        seed: Seed for the default model's weight initialization. Repeated
            calls without ``model`` then use identical weights and return
            comparable fingerprints. Ignored when ``model`` is given.

    Returns:
        1-D NumPy array of length ``embedding_dim`` (128 for the default model).
    """
    waveform = load_audio(input_path)
    mel = get_mel(waveform)
    if model is None:
        # A fresh random init on every call would give the same file a
        # different fingerprint each time. Seed only the CPU generator inside
        # a forked RNG scope so the caller's global RNG state is untouched.
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(seed)
            model = AudioFingerprintNet()
    # Run in eval mode, but restore a caller-supplied model's original mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            embedding = model(mel)
    finally:
        model.train(was_training)
    # Drop only the batch axis so an embedding_dim of 1 still yields a 1-D array.
    return embedding.squeeze(0).numpy()

if __name__ == "__main__":
    import sys

    # Take the audio file from the command line if one is given; otherwise
    # fall back to the placeholder path.
    audio_path = sys.argv[1] if len(sys.argv) > 1 else "yourfile.wav"
    fingerprint = generate_fingerprint(audio_path)
    print(fingerprint.shape)  # (128,) with the default model

    plt.plot(fingerprint)
    plt.title("Audio Fingerprint Vector")
    plt.xlabel("Dimension Index")
    plt.ylabel("Value")
    plt.show()
[1] Enhancing Neural Audio Fingerprint Robustness to Audio Degradation for Music Identification

[2] Neural Audio Fingerprint for High-specific Audio Retrieval Based on Contrastive Learning

[3] Music Identification with Audio Fingerprinting. An Industrial Perspective