Pretrained Mono CNN in JUCE Plugin

—*Testing Stage* —

I tried embedding a pretrained mono CNN in a JUCE plugin to experiment with real-time audio processing, with a vocal denoiser as the main use case in mind. The process turned out simpler than I expected, and it was fun to see a neural network running inside a plugin.

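I started with a medium-sized CNN trained on mono audio frames (two 1-D convolution layers over 128-sample frames, as in the example code below). To make it fast enough for real-time use, I applied post-training int8 quantization. Weights are stored in 8 bits instead of 32, so they take roughly a quarter of the space, and the model ran much quicker without losing much accuracy. I then exported it with TorchScript, which produces a standalone .pt file that LibTorch (the PyTorch C++ API) can load inside the plugin without a Python runtime.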
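To see where the size and speed savings come from, here is a minimal sketch of per-tensor affine int8 quantization in plain Python. A float range is mapped onto 256 integer levels with a scale and a zero point, and calibration is what estimates that range for activations. The helper names are hypothetical, and PyTorch's "fbgemm" qconfig uses its own observers (per-channel weights, quint8 activations), so its exact numbers will differ.

```python
# Minimal per-tensor affine int8 quantization (illustration only, not PyTorch's observers).

def choose_qparams(min_val, max_val, qmin=-128, qmax=127):
    """Map the observed float range [min_val, max_val] onto the int8 range [qmin, qmax]."""
    min_val = min(min_val, 0.0)  # keep 0.0 exactly representable (zero padding, silence)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin) or 1.0  # all-zero input: any scale works
    zero_point = round(qmin - min_val / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point

def quantize(xs, scale, zero_point, qmin=-128, qmax=127):
    """q = clamp(round(x / scale) + zero_point, qmin, qmax)"""
    return [max(qmin, min(qmax, round(x / scale) + zero_point)) for x in xs]

def dequantize(qs, scale, zero_point):
    """x ~ scale * (q - zero_point)"""
    return [scale * (q - zero_point) for q in qs]

weights = [-0.5, -0.1, 0.0, 0.2, 0.75]
scale, zp = choose_qparams(min(weights), max(weights))
q = quantize(weights, scale, zp)
restored = dequantize(q, scale, zp)
max_err = max(abs(a - b) for a, b in zip(weights, restored))

print("int8 codes:", q)
print(f"storage: {4 * len(weights)} bytes as float32, {len(q)} bytes as int8")
print(f"max round-trip error {max_err:.5f} (half a step is {scale / 2:.5f})")
```

Values inside the calibrated range come back within half a quantization step; values outside it get clipped, which is why calibrating on realistic frames matters.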

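On the JUCE side, loading the model is straightforward: link LibTorch into the plugin and call torch::jit::load() with the path to the .pt file, ideally once when the plugin is created rather than inside the audio callback. Since the model was trained on mono audio, stereo input has to be downmixed first, or the plugin can be restricted to a mono bus layout. I split the audio into frames matching the CNN's input size (128 samples), ran each frame through the model, and wrote the output back to the buffer. The host chooses the block size, and it need not match the frame size, so in general the samples have to be buffered in a FIFO; that adds one frame of latency, which the plugin should report to the host (setLatencySamples() in JUCE). The int8 model ran smoothly enough for real-time processing.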
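Here is a language-agnostic sketch of the downmix and framing logic, written in Python for brevity; in the plugin it would live in the C++ processBlock(). FRAME_SIZE, FrameProcessor, and process_frame are hypothetical names. process_frame stands in for the model call and is assumed to return FRAME_SIZE samples, as a denoiser would (the example classifier later in this post outputs 10 class scores instead).

```python
from collections import deque

FRAME_SIZE = 128  # hypothetical constant: must equal the model's input length (1 x 1 x 128)

def downmix(left, right):
    """Average left and right channels into a single mono channel."""
    return [0.5 * (l + r) for l, r in zip(left, right)]

class FrameProcessor:
    """Regroups host blocks of any size into fixed-size frames.

    A frame can only be processed once it is full, so output lags input by
    exactly frame_size samples; the plugin should report this as latency.
    """

    def __init__(self, process_frame, frame_size=FRAME_SIZE):
        self.process_frame = process_frame  # stand-in for the model call
        self.frame_size = frame_size
        self.in_fifo = []
        # Pre-filled with silence: this is the one-frame latency
        self.out_fifo = deque([0.0] * frame_size)

    def process_block(self, block):
        out = []
        for sample in block:
            self.in_fifo.append(sample)
            if len(self.in_fifo) == self.frame_size:
                processed = self.process_frame(self.in_fifo)
                if len(processed) != self.frame_size:
                    raise ValueError("model output length must equal the frame size")
                self.out_fifo.extend(processed)
                self.in_fifo = []
            out.append(self.out_fifo.popleft())
        return out

# Usage: an identity "model" and stereo input delivered in 100-sample host blocks
left = [float(i + 1) for i in range(300)]
right = [float(i + 1) for i in range(300)]
mono = downmix(left, right)

proc = FrameProcessor(lambda frame: list(frame))
out = []
for start in range(0, len(mono), 100):
    out.extend(proc.process_block(mono[start:start + 100]))
```

The first 128 output samples are silence and everything after is the processed signal shifted by one frame. A C++ version would use preallocated ring buffers so nothing is allocated on the audio thread.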

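A few things I learned along the way. Calibrate quantization with real audio frames when possible: the observers set the int8 ranges from whatever data they see, and random noise is distributed very differently from speech. Double-check that the frame size matches the CNN's input (here 128 samples, which fc1 hard-codes as 32*128). Compare float32 and int8 outputs on the same frames to confirm the quantized model still behaves as expected. Overall, this workflow makes it fairly easy to add neural audio processing to a JUCE plugin while keeping CPU load manageable.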
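The float32 versus int8 comparison can be as simple as running the same frames through both models and summarizing the differences. This is a minimal sketch in plain Python, assuming the outputs have already been turned into lists of floats (for example with tensor.tolist()). compare_outputs, top1_agreement, and the toy numbers are illustrative, not measured results, and what counts as "close enough" depends on the task.

```python
import math

def compare_outputs(ref, quant):
    """Max absolute difference and SNR (dB) of int8 outputs against the float32 reference."""
    if len(ref) != len(quant):
        raise ValueError(f"length mismatch: {len(ref)} vs {len(quant)}")
    max_abs = max(abs(a - b) for a, b in zip(ref, quant))
    signal = sum(a * a for a in ref)
    noise = sum((a - b) ** 2 for a, b in zip(ref, quant))
    if noise == 0:
        snr_db = math.inf
    elif signal == 0:
        snr_db = -math.inf
    else:
        snr_db = 10 * math.log10(signal / noise)
    return max_abs, snr_db

def argmax(values):
    return max(range(len(values)), key=values.__getitem__)

def top1_agreement(ref_logits, quant_logits):
    """Fraction of frames where the float32 and int8 models pick the same class."""
    same = sum(argmax(r) == argmax(q) for r, q in zip(ref_logits, quant_logits))
    return same / len(ref_logits)

# Toy values for illustration only: two frames, three class scores each
float_out = [[0.1, 0.9, -0.3], [0.8, 0.2, 0.1]]
int8_out = [[0.12, 0.88, -0.3], [0.35, 0.6, 0.1]]

for f, i in zip(float_out, int8_out):
    max_abs, snr = compare_outputs(f, i)
    print(f"max |diff| = {max_abs:.3f}, SNR = {snr:.1f} dB")
print("top-1 agreement:", top1_agreement(float_out, int8_out))
```

For a denoiser, the SNR between the float32 and int8 output frames is the more telling number; for a classifier, top-1 agreement is a quick sanity check.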

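Python side (PyTorch):

```python
import torch
import torch.nn as nn
import torch.quantization

# --- Example medium CNN for audio frame classification ---
# Input: (batch, 1, 128) mono frames. The 10-class head is only an example;
# a denoiser would output a 128-sample frame instead.

class AudioCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 at the input
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(16)
        self.relu1 = nn.ReLU()                           # ReLU as modules so they can be fused
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(32)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(32 * 128, 64)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)  # example: 10 classes
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = torch.flatten(x, 1)  # (batch, 32 * 128)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return self.dequant(x)

# --- Convert to int8 (post-training static quantization) ---

# Load the previously trained AudioCNN weights (stubs and ReLU modules add no parameters)
model = AudioCNN()
model.load_state_dict(torch.load("audio_cnn.pth", map_location="cpu"))
model.eval()

# Fuse conv+bn+relu and linear+relu so BatchNorm is folded in before quantization
torch.quantization.fuse_modules(
    model,
    [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["fc1", "relu3"]],
    inplace=True,
)

# "fbgemm" targets x86; use "qnnpack" on ARM (e.g. Apple Silicon) instead
torch.backends.quantized.engine = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)

# Calibration: run representative frames through the observers.
# Real audio frames are recommended; random data here is only a placeholder.
with torch.no_grad():
    model(torch.randn(32, 1, 128))

# Convert to int8
torch.quantization.convert(model, inplace=True)

# --- TorchScript export for C++ ---
scripted_model = torch.jit.script(model)
scripted_model.save("audio_cnn_int8.pt")
print("Quantized TorchScript model saved: audio_cnn_int8.pt")
```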

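C++ side (LibTorch):

```cpp
#include <torch/script.h>  // LibTorch: torch::jit::load, torch::from_blob
#include <torch/torch.h>   // torch::NoGradGuard
#include <iostream>
#include <vector>

int main()
{
    // Load the scripted model once. In a plugin, do this in the constructor or
    // prepareToPlay(), never inside the audio callback.
    torch::jit::script::Module model = torch::jit::load("audio_cnn_int8.pt");
    model.eval();

    // Skip autograd bookkeeping during inference
    torch::NoGradGuard noGrad;

    // Example mono frame (128 float samples), e.g. collected from the JUCE callback
    std::vector<float> audioFrame(128, 0.0f);

    // Wrap the samples as a (batch=1, channels=1, samples=128) tensor.
    // from_blob does not copy, so audioFrame must outlive `input`.
    torch::Tensor input = torch::from_blob(audioFrame.data(), {1, 1, 128}, torch::kFloat32);

    // Run inference
    torch::Tensor output = model.forward({input}).toTensor();

    // Print the output shape ([1, 10] for the example 10-class model)
    std::cout << "Output shape: " << output.sizes() << std::endl;

    return 0;
}
```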