Arc-Length-Based Sampling in Latent Space

Here, I morph between two points (A to B) through a control point (used by the curved path) and compare arc-length-based sampling with linear interpolation. The main idea comes from my earlier implementation of curve point distribution, which spaces points along a 2D curve by arc length while tolerating a small error (shown at the bottom).

  1. The toy audio decoder adds two sinusoids (the fundamental and its third harmonic) and applies a quadratic decay envelope to simulate “timbre morphing” for testing.
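  2. The arc-length machinery builds a smooth 2D path through A, the control point, and B using cubic splines. With only three nodes, SciPy's default not-a-knot condition makes each coordinate a parabola in t. Unlike a Bézier curve, the path therefore passes through the control point. It returns the path derivatives x′(t) and y′(t), then integrates the speed √(x′(t)² + y′(t)²) into an arc-length lookup table that is used to place the sample points.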
  3. The morphing logic generates linear and arc-length-based morphs. Arc-length sampling spaces points equally along the curve in latent space, avoiding the apparent acceleration/deceleration that uniform-t sampling produces. It also creates the corresponding audio files.
  4. The final part decodes each latent point in turn and saves the audio files. The left plot shows the full curve (black dashed), the linear interpolation path (blue), the arc-length sample points (red), and the deviation vectors (green). The right plot compares the Euclidean distance between consecutive steps.

 

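Arc-length parameterization guarantees uniform spacing along the path in latent space, which promotes perceptual uniformity when navigating it. In audio synthesis, this avoids uneven transitions in timbre, pitch, or loudness that can occur when sampling directly in t-space. The benefit carries over to perception only to the extent that the decoder maps equal latent distances to comparable perceptual changes. Even so, spacing latent samples evenly by curve distance is the natural starting point for consistent morphing, especially with nonlinear models like VAEs, GANs, or codec decoders.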

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import quad
from scipy.interpolate import CubicSpline
from scipy.io import wavfile

# ==============================================================================
# PART 1: "TIMBRE" DECODER
# ==============================================================================
class ToyAudioDecoder:
    """Minimal stand-in for a neural audio decoder, used only for testing.

    A 2D latent ``z = (x, y)`` is rendered as a short tone:
      * ``x`` sets the fundamental frequency, 100 * exp(1.5 * x) Hz;
      * ``y`` sets the third-harmonic amplitude, (y + 1) / 2.
    The fundamental is mixed in at a fixed amplitude of 0.7, and the sum is
    shaped by a quadratic fade-out envelope that reaches 0 at the last sample.
    """

    def __init__(self, sr=44100, duration_s=0.15):
        self.sr = sr
        self.duration_s = duration_s
        n_samples = int(sr * duration_s)
        # Sample times in seconds, end point excluded.
        self.t = np.linspace(0, duration_s, n_samples, endpoint=False)

    def decode(self, z):
        """Render latent point ``z`` into a 1-D waveform with ``len(self.t)`` samples."""
        x_coord, y_coord = z
        f0 = 100 * np.exp(1.5 * x_coord)
        fundamental = 0.7 * np.sin(2 * np.pi * f0 * self.t)
        third = ((y_coord + 1) / 2) * np.sin(2 * np.pi * (f0 * 3) * self.t)
        envelope = np.linspace(1, 0, len(self.t)) ** 2
        return (fundamental + third) * envelope

# ==============================================================================
# PART 2: PATH DEFINITION AND ARC-LENGTH MACHINERY
# ==============================================================================
def create_path(start_z, end_z, control_point_z):
    """Build a smooth parametric path through start -> control -> end.

    The three points are placed at t = 0, 0.5 and 1, and each coordinate is
    interpolated independently with ``scipy.interpolate.CubicSpline`` (default
    boundary conditions), so the path passes through the control point.

    Returns:
        (path_x, path_y): spline callables giving x(t) and y(t) for t in [0, 1].
    """
    nodes = np.array([start_z, control_point_z, end_z])
    t_knots = np.linspace(0, 1, nodes.shape[0])
    path_x, path_y = (CubicSpline(t_knots, nodes[:, axis]) for axis in (0, 1))
    return path_x, path_y

def get_path_derivatives(path_x, path_y):
    """Return the first-derivative splines ``(dx/dt, dy/dt)`` of the given path splines."""
    return tuple(spline.derivative() for spline in (path_x, path_y))

def path_integrand(t, dx_dt, dy_dt):
    """Speed of the path at ``t``: sqrt(x'(t)^2 + y'(t)^2).

    Integrating this over t yields arc length; ``dx_dt`` and ``dy_dt`` are
    callables evaluating the coordinate derivatives.
    """
    vx = dx_dt(t)
    vy = dy_dt(t)
    squared_speed = vx ** 2 + vy ** 2
    return np.sqrt(squared_speed)

def create_arc_length_lut_for_path(path_derivatives, t_max=1.0, num_samples=1000):
    """Tabulate cumulative arc length s(t) on a uniform grid of the path parameter.

    Args:
        path_derivatives: pair of callables ``(dx/dt, dy/dt)``, e.g. as returned
            by ``get_path_derivatives``.
        t_max: upper end of the parameter range; the grid covers [0, t_max].
        num_samples: number of grid points, including both ends.

    Returns:
        (t_samples, s_values): the grid and the arc length from t = 0 to each grid
        point. ``s_values[0] == 0`` and the table is non-decreasing, so it can be
        inverted with ``np.interp`` to map arc length back to t.
    """
    dx_dt, dy_dt = path_derivatives
    t_samples = np.linspace(0, t_max, num_samples)
    # Integrate each short interval [t_{i-1}, t_i] once and accumulate, instead of
    # re-integrating from 0 for every sample. Every piece is >= 0 (the integrand is
    # a speed), so the running sum is monotone -- which np.interp requires when the
    # table is inverted -- and each quad call works on a short interval.
    segment_lengths = [
        quad(path_integrand, t0, t1, args=(dx_dt, dy_dt))[0]
        for t0, t1 in zip(t_samples[:-1], t_samples[1:])
    ]
    s_values = np.zeros_like(t_samples)
    s_values[1:] = np.cumsum(segment_lengths)
    return t_samples, s_values

# ==============================================================================
# PART 3: THE MORPHING LOGIC
# ==============================================================================
def generate_morph(path_logic, num_steps, start_z, end_z, control_z=None):
    """Generate ``num_steps`` latent points for a morph from ``start_z`` to ``end_z``.

    Args:
        path_logic: how the points are chosen:
            * ``'linear'`` -- evenly spaced points on the straight segment
              start -> end (works for latent vectors of any dimension);
            * ``'arc_length'`` -- points on the 2D spline path
              start -> control -> end (see ``create_path``), spaced evenly by
              arc length;
            * ``'uniform_t'`` -- points on that same spline path, spaced evenly in
              the spline parameter t (the naive sampling that arc-length
              parameterization corrects; a same-curve baseline).
        num_steps: number of points, including both end points.
        start_z, end_z: latent end points.
        control_z: control point; required for the curved modes.

    Returns:
        Array of shape (num_steps, dim): dim = len(start_z) for 'linear', 2 otherwise.

    Raises:
        ValueError: unknown ``path_logic``, or a curved mode without ``control_z``.
    """
    if path_logic == 'linear':
        start = np.asarray(start_z, dtype=float)
        end = np.asarray(end_z, dtype=float)
        t_values = np.linspace(0, 1, num_steps)
        # Broadcast over every latent dimension: (num_steps, 1) * (dim,) -> (num_steps, dim).
        return start + t_values[:, None] * (end - start)

    if path_logic not in ('arc_length', 'uniform_t'):
        raise ValueError(
            f"Unknown path logic {path_logic!r}; "
            "expected 'linear', 'arc_length' or 'uniform_t'."
        )
    if control_z is None:
        raise ValueError(f"A control point `control_z` is required for {path_logic} morph.")

    path_x, path_y = create_path(start_z, end_z, control_z)
    if path_logic == 'uniform_t':
        resampled_t = np.linspace(0, 1, num_steps)
    else:
        derivatives = get_path_derivatives(path_x, path_y)
        lut_t, lut_s = create_arc_length_lut_for_path(derivatives)
        # Evenly spaced targets along the curve, mapped back to t by inverting s(t).
        target_s = np.linspace(0, lut_s[-1], num_steps)
        resampled_t = np.interp(target_s, lut_s, lut_t)
    return np.column_stack([path_x(resampled_t), path_y(resampled_t)])

def create_audio_file(filename, latent_points, decoder):
    """Decode a sequence of latent points and write them back to back as a 16-bit WAV.

    Each point is rendered with ``decoder.decode(z)``. The snippets are
    concatenated in order, peak-normalised to the int16 range, and written at
    ``decoder.sr``.

    Args:
        filename: output path of the WAV file.
        latent_points: iterable of latent points accepted by ``decoder.decode``.
        decoder: object with a ``decode(z)`` method and an ``sr`` sample rate.

    Raises:
        ValueError: if ``latent_points`` is empty (there is nothing to write).
    """
    snippets = [decoder.decode(z) for z in latent_points]
    if not snippets:
        raise ValueError("latent_points must contain at least one point")
    final_waveform = np.concatenate(snippets)

    # Peak-normalise to full int16 scale. An all-silent (or empty) signal is written
    # as-is instead of dividing by a zero peak, which would produce NaNs.
    peak = np.max(np.abs(final_waveform)) if final_waveform.size else 0.0
    if peak > 0:
        final_waveform = final_waveform / peak
    waveform_int16 = np.int16(final_waveform * 32767)

    wavfile.write(filename, decoder.sr, waveform_int16)

# ==============================================================================
# PART 4: MAIN EXECUTION
# ==============================================================================
if __name__ == "__main__":
    # Morph endpoints and the control point that bends the curved path.
    z_A = np.array([-1.0, -1.0])
    z_B = np.array([1.0, -1.0])
    z_Control = np.array([0.0, 1.0])
    num_morph_steps = 60

    # --- Generate Paths ---
    # Straight line A -> B, sampled uniformly in t.
    linear_points = generate_morph('linear', num_morph_steps, z_A, z_B)
    # Curve A -> Control -> B, sampled uniformly in arc length.
    arc_length_points = generate_morph('arc_length', num_morph_steps, z_A, z_B, z_Control)
    # Per-step offset from the straight line to the curve (the green deviation vectors).
    difference_points = arc_length_points - linear_points

    # --- Audio Generation ---
    decoder = ToyAudioDecoder()
    create_audio_file("morph_linear.wav", linear_points, decoder)
    create_audio_file("morph_arc_length.wav", arc_length_points, decoder)
    # Experimental: the deviation vectors are decoded as if they were latent points.
    # They are offsets, not positions on either path, so this is not a morph A -> B.
    create_audio_file("difference_morph.wav", difference_points, decoder)

    # --- VISUALIZATION BLOCK ---
    fig, (ax_path, ax_dist) = plt.subplots(1, 2, figsize=(14, 6))

    # --- Latent Path Plot ---
    # Dense reference rendering of the curved path, drawn first so it sits underneath.
    path_x_ref, path_y_ref = create_path(z_A, z_B, z_Control)
    t_fine = np.linspace(0, 1, 200)
    ax_path.plot(path_x_ref(t_fine), path_y_ref(t_fine), 'k--', linewidth=1.5, alpha=0.6,
                 label='Full Curved Path')

    ax_path.plot(linear_points[:, 0], linear_points[:, 1], 'bo-', markersize=5,
                 label='Linear Path')

    # Markers only (no connecting line) so the dashed reference curve stays visible.
    ax_path.plot(arc_length_points[:, 0], arc_length_points[:, 1], 'ro', markersize=6,
                 label='Arc-Length Samples')

    # Deviation segments; only the first one carries the legend entry.
    for i, (p_lin, p_arc) in enumerate(zip(linear_points, arc_length_points)):
        ax_path.plot([p_lin[0], p_arc[0]], [p_lin[1], p_arc[1]], 'g--', linewidth=0.8,
                     alpha=0.7, label='Deviation' if i == 0 else '_nolegend_')

    ax_path.set_title('Paths in "Timbre" Latent Space')
    ax_path.set_xlabel('Latent X (Fundamental Freq)')
    ax_path.set_ylabel('Latent Y (Harmonic Amplitude)')
    ax_path.grid(True)
    ax_path.legend()
    ax_path.set_aspect('equal', adjustable='box')

    # --- Distance Plot ---
    linear_distances = np.linalg.norm(np.diff(linear_points, axis=0), axis=1)
    arc_distances = np.linalg.norm(np.diff(arc_length_points, axis=0), axis=1)
    ax_dist.plot(linear_distances, 'bo-', label='Distance between Linear Steps')
    ax_dist.plot(arc_distances, 'ro-', label='Distance between Arc-Length Steps')
    ax_dist.set_title('Latent Space Distance per Morph Step')
    ax_dist.set_xlabel('Morph Step Number')
    ax_dist.set_ylabel('Euclidean Distance in Latent Space')
    ax_dist.grid(True)
    ax_dist.legend()

    fig.tight_layout()
    plt.show()

 

The purpose of this curve point distribution implementation (the original idea) is to distribute a desired number of points on a curve with equal path distance between neighbors. To do that, I applied an adaptive unit-length strategy: after every placed point, the target segment length is recomputed from the remaining arc length. x is then swept forward in small increments until the new segment's length is within 0.1% of that target, or first exceeds it. The remaining errors are kept on purpose; they give the distribution a slightly “humanized” feel.

Here, I chose the curve y = 2x^3 - x + 1 on the range [-1, 1], with 20 points and delta_x = 0.001:

import numpy as np
from scipy.integrate import quad
import matplotlib.pyplot as plt

def function(x):
    """Evaluate the example cubic y = 2x^3 - x + 1 (scalar or NumPy array ``x``)."""
    cubic_term = 2 * x ** 3
    return cubic_term - x + 1
    
def integrand(x):
    """Arc-length integrand sqrt(1 + (dy/dx)^2) for y = 2x^3 - x + 1.

    The derivative of the cubic is dy/dx = 6x^2 - 1.
    """
    return np.sqrt(1 + (6 * x ** 2 - 1) ** 2)
    
def curve_Length(a, b):
    """Arc length of the example curve between x = a and x = b.

    Computed by adaptive quadrature (``scipy.integrate.quad``) of ``integrand``;
    the error estimate that ``quad`` also returns is discarded.
    """
    return quad(integrand, a, b)[0]
    
def calculate_Curve_Points(numPoints, deltaX=0.001, a=-1, b=1):
    """Place ``numPoints`` points on y = function(x), x in [a, b], at roughly equal arc length.

    Starting at x = a, each next point is found by sweeping x forward in steps of
    ``deltaX``. The sweep stops when the new segment's arc length is within 0.1%
    of the target, or as soon as it first exceeds the target. The target is
    recomputed for every segment as (remaining arc length to b) / (remaining
    segments), so sweep overshoot does not accumulate. The last point is always
    exactly x = b. The per-segment errors (in %) are printed.

    Args:
        numPoints: total number of points to return, including both endpoints (>= 2).
        deltaX: x-increment of the sweep; smaller values reduce the error.
        a, b: x-interval of the curve, with a < b.

    Returns:
        List of ``numPoints`` (x, y) tuples ordered by x.

    Raises:
        ValueError: if numPoints < 2, deltaX <= 0, or b <= a.
    """
    if numPoints < 2:
        raise ValueError("numPoints must be at least 2")
    if deltaX <= 0:
        raise ValueError("deltaX must be positive")
    if b <= a:
        raise ValueError("b must be greater than a")

    # First point: the lower bound of the interval.
    curvePointCoordinates = [(a, function(a))]
    errorList = []
    # Only the numPoints - 2 interior points are searched for; the final point is
    # pinned to b below, so the result has exactly numPoints points with no
    # near-duplicate at the end.
    for k in range(numPoints - 2):
        lowerBound = curvePointCoordinates[-1][0]
        # Adaptive target: share the remaining arc length among the remaining segments.
        unitLength = curve_Length(lowerBound, b) / (numPoints - 1 - k)
        distanceX = deltaX
        while True:
            # Never step past the end of the interval.
            upperBound = min(lowerBound + distanceX, b)
            segmentLength = curve_Length(lowerBound, upperBound)
            error = 100 * abs(unitLength - segmentLength) / unitLength
            if error <= 0.1 or segmentLength >= unitLength or upperBound >= b:
                curvePointCoordinates.append((upperBound, function(upperBound)))
                errorList.append(round(error, 2))
                break
            distanceX += deltaX

    print("Segment Errors (%):")
    for i, err in enumerate(errorList):
        print(f"Segment {i + 1}: {err}%")

    # Last point: exactly the upper bound; its segment takes whatever length remains.
    curvePointCoordinates.append((b, function(b)))

    return curvePointCoordinates
    
def plot_Curve(points, title='Curve Point Distribution', curve_label='y = 2x³ - x + 1'):
    """Plot the distributed points joined by straight segments, with equal axis scaling.

    Args:
        points: sequence of (x, y) tuples, e.g. from ``calculate_Curve_Points``.
        title: figure title.
        curve_label: legend text identifying the curve; defaults to the example cubic.
            (Previously the formula was hard-coded as the x-axis label.)
    """
    x_coords = [p[0] for p in points]
    y_coords = [p[1] for p in points]
    plt.figure(dpi=300)
    plt.plot(x_coords, y_coords, marker='o', linestyle='-', color='b', label=curve_label)
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend()
    plt.grid(True)
    # Equal aspect so equal arc-length spacing is visible as equal on-screen spacing.
    plt.gca().set_aspect('equal', adjustable='box')
    plt.show()
    
if __name__ == "__main__":
    # Example run: 20 points on y = 2x^3 - x + 1 over [-1, 1], swept in x-steps of 0.001.
    numPoints, deltaX = 20, 0.001
    a, b = -1.0, 1.0
    curvePoints = calculate_Curve_Points(numPoints=numPoints, deltaX=deltaX, a=a, b=b)
    plot_Curve(points=curvePoints)
[1] Arc Length Formula

[2] Arc Length Parameterization

[3] Desmos calculator