grafos_dsp/
stages.rs

1//! Built-in DSP processing stages.
2
3extern crate alloc;
4use alloc::vec;
5use alloc::vec::Vec;
6
7use grafos_std::error::FabricError;
8
9use crate::fft;
10use crate::types::{Block, Complex, Sample};
11
12/// Trait for a DSP processing stage that transforms blocks.
13pub trait DspStage {
14    /// Process a block of samples, returning a transformed block.
15    fn process(&mut self, block: &Block) -> Result<Block, FabricError>;
16}
17
/// Forward FFT stage: turns real time-domain samples into complex
/// frequency-domain bins.
///
/// The number of frames in the input block must be a power of 2. The output
/// holds one contiguous run of interleaved `[re, im, re, im, ...]` pairs per
/// channel, channel after channel.
///
/// When the crate is built with the `gpu` feature, a stage created through
/// `FftStage::with_gpu()` dispatches the transform through `grafos-tensor`'s
/// GPU path and falls back to the CPU when the tensor cannot be placed on
/// the GPU.
pub struct FftStage {
    /// Whether GPU dispatch was requested (only present with the `gpu` feature).
    #[cfg(feature = "gpu")]
    gpu: bool,
}
30
31impl FftStage {
32    /// Create a CPU-only FFT stage.
33    pub fn new() -> Self {
34        Self {
35            #[cfg(feature = "gpu")]
36            gpu: false,
37        }
38    }
39
40    /// Create an FFT stage with GPU dispatch enabled.
41    ///
42    /// When the `gpu` feature is not compiled in, this is identical to `new()`.
43    pub fn with_gpu() -> Self {
44        Self {
45            #[cfg(feature = "gpu")]
46            gpu: true,
47        }
48    }
49
50    fn process_cpu(&self, block: &Block) -> Result<Block, FabricError> {
51        let n = block.frames();
52        if n == 0 || !n.is_power_of_two() {
53            return Err(FabricError::CapacityExceeded);
54        }
55
56        let channels = block.channels as usize;
57        let mut output_data = Vec::with_capacity(n * 2 * channels);
58
59        for ch in 0..channels {
60            let channel_data: Vec<f32> = (0..n).map(|i| block.data[i * channels + ch]).collect();
61            let freq = fft::fft_real(&channel_data);
62            for c in &freq {
63                output_data.push(c.re);
64                output_data.push(c.im);
65            }
66        }
67
68        Ok(Block {
69            data: output_data,
70            sample_rate: block.sample_rate,
71            channels: block.channels,
72        })
73    }
74
75    #[cfg(feature = "gpu")]
76    fn process_gpu(&self, block: &Block) -> Result<Block, FabricError> {
77        let n = block.frames();
78        if n == 0 || !n.is_power_of_two() {
79            return Err(FabricError::CapacityExceeded);
80        }
81
82        let channels = block.channels as usize;
83        let mut output_data = Vec::with_capacity(n * 2 * channels);
84
85        for ch in 0..channels {
86            let channel_data: Vec<f32> = (0..n).map(|i| block.data[i * channels + ch]).collect();
87            let tensor = grafos_tensor::FabricTensor::from_slice(&[n], &channel_data)?;
88            let gpu_tensor = match tensor.to_gpu() {
89                Ok(t) => t,
90                Err(_) => {
91                    // GPU unavailable, fall back to CPU for this channel
92                    let freq = fft::fft_real(&channel_data);
93                    for c in &freq {
94                        output_data.push(c.re);
95                        output_data.push(c.im);
96                    }
97                    continue;
98                }
99            };
100            let freq_tensor = gpu_tensor.fft()?;
101            let cpu_result = freq_tensor.to_cpu()?;
102            output_data.extend_from_slice(cpu_result.as_slice());
103        }
104
105        Ok(Block {
106            data: output_data,
107            sample_rate: block.sample_rate,
108            channels: block.channels,
109        })
110    }
111}
112
113impl Default for FftStage {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl DspStage for FftStage {
120    fn process(&mut self, block: &Block) -> Result<Block, FabricError> {
121        #[cfg(feature = "gpu")]
122        if self.gpu {
123            return self.process_gpu(block);
124        }
125        self.process_cpu(block)
126    }
127}
128
/// Inverse FFT stage: turns complex frequency-domain bins back into real
/// time-domain samples.
///
/// The input is expected to hold, for each channel in turn, one contiguous
/// run of interleaved `[re, im, re, im, ...]` pairs.
///
/// When the crate is built with the `gpu` feature, a stage created through
/// `IfftStage::with_gpu()` dispatches the transform through `grafos-tensor`'s
/// GPU path and falls back to the CPU when the tensor cannot be placed on
/// the GPU.
pub struct IfftStage {
    /// Whether GPU dispatch was requested (only present with the `gpu` feature).
    #[cfg(feature = "gpu")]
    gpu: bool,
}
140
141impl IfftStage {
142    /// Create a CPU-only IFFT stage.
143    pub fn new() -> Self {
144        Self {
145            #[cfg(feature = "gpu")]
146            gpu: false,
147        }
148    }
149
150    /// Create an IFFT stage with GPU dispatch enabled.
151    ///
152    /// When the `gpu` feature is not compiled in, this is identical to `new()`.
153    pub fn with_gpu() -> Self {
154        Self {
155            #[cfg(feature = "gpu")]
156            gpu: true,
157        }
158    }
159
160    fn process_cpu(&self, block: &Block) -> Result<Block, FabricError> {
161        let channels = block.channels as usize;
162        if channels == 0 {
163            return Err(FabricError::CapacityExceeded);
164        }
165
166        let total_per_channel = block.data.len() / channels;
167        if !total_per_channel.is_multiple_of(2) {
168            return Err(FabricError::CapacityExceeded);
169        }
170        let n_complex = total_per_channel / 2;
171        if !n_complex.is_power_of_two() {
172            return Err(FabricError::CapacityExceeded);
173        }
174
175        let mut output_data = vec![0.0f32; n_complex * channels];
176
177        for ch in 0..channels {
178            let base = ch * total_per_channel;
179            let complex_data: Vec<Complex> = (0..n_complex)
180                .map(|i| Complex::new(block.data[base + i * 2], block.data[base + i * 2 + 1]))
181                .collect();
182            let time_domain = fft::ifft_real(&complex_data);
183            for (i, &s) in time_domain.iter().enumerate() {
184                output_data[i * channels + ch] = s;
185            }
186        }
187
188        Ok(Block {
189            data: output_data,
190            sample_rate: block.sample_rate,
191            channels: block.channels,
192        })
193    }
194
195    #[cfg(feature = "gpu")]
196    fn process_gpu(&self, block: &Block) -> Result<Block, FabricError> {
197        let channels = block.channels as usize;
198        if channels == 0 {
199            return Err(FabricError::CapacityExceeded);
200        }
201
202        let total_per_channel = block.data.len() / channels;
203        if total_per_channel % 2 != 0 {
204            return Err(FabricError::CapacityExceeded);
205        }
206        let n_complex = total_per_channel / 2;
207        if !n_complex.is_power_of_two() {
208            return Err(FabricError::CapacityExceeded);
209        }
210
211        let mut output_data = vec![0.0f32; n_complex * channels];
212
213        for ch in 0..channels {
214            let base = ch * total_per_channel;
215            let freq_data: Vec<f32> = block.data[base..base + total_per_channel].to_vec();
216            let tensor = grafos_tensor::FabricTensor::from_slice(&[total_per_channel], &freq_data)?;
217            let gpu_tensor = match tensor.to_gpu() {
218                Ok(t) => t,
219                Err(_) => {
220                    // GPU unavailable, fall back to CPU for this channel
221                    let complex_data: Vec<Complex> = (0..n_complex)
222                        .map(|i| {
223                            Complex::new(block.data[base + i * 2], block.data[base + i * 2 + 1])
224                        })
225                        .collect();
226                    let time_domain = fft::ifft_real(&complex_data);
227                    for (i, &s) in time_domain.iter().enumerate() {
228                        output_data[i * channels + ch] = s;
229                    }
230                    continue;
231                }
232            };
233            let time_tensor = gpu_tensor.ifft()?;
234            let cpu_result = time_tensor.to_cpu()?;
235            let recovered = cpu_result.as_slice();
236            for (i, &s) in recovered.iter().enumerate() {
237                output_data[i * channels + ch] = s;
238            }
239        }
240
241        Ok(Block {
242            data: output_data,
243            sample_rate: block.sample_rate,
244            channels: block.channels,
245        })
246    }
247}
248
249impl Default for IfftStage {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255impl DspStage for IfftStage {
256    fn process(&mut self, block: &Block) -> Result<Block, FabricError> {
257        #[cfg(feature = "gpu")]
258        if self.gpu {
259            return self.process_gpu(block);
260        }
261        self.process_cpu(block)
262    }
263}
264
/// FIR (Finite Impulse Response) filter stage.
///
/// Filters each channel by direct-form convolution with a fixed list of
/// coefficients. Input history is kept per channel between calls so that a
/// stream split into consecutive blocks can be filtered continuously.
pub struct FirFilterStage {
    /// Filter taps; tap `j` weights the sample `j` frames in the past.
    coefficients: Vec<f32>,
    /// Per-channel input samples carried over from previous blocks.
    overlap: Vec<Vec<f32>>,
    /// Set once the per-channel history has been allocated.
    initialized: bool,
}
275
276impl FirFilterStage {
277    /// Create a new FIR filter with the given coefficients.
278    pub fn new(coefficients: Vec<f32>) -> Self {
279        Self {
280            coefficients,
281            overlap: Vec::new(),
282            initialized: false,
283        }
284    }
285}
286
287impl DspStage for FirFilterStage {
288    fn process(&mut self, block: &Block) -> Result<Block, FabricError> {
289        let channels = block.channels as usize;
290        let frames = block.frames();
291        let order = self.coefficients.len();
292
293        if !self.initialized {
294            self.overlap = (0..channels).map(|_| vec![0.0; order - 1]).collect();
295            self.initialized = true;
296        }
297
298        let mut output_data = vec![0.0f32; block.data.len()];
299
300        for ch in 0..channels {
301            let input: Vec<f32> = (0..frames).map(|i| block.data[i * channels + ch]).collect();
302
303            // Prepend overlap from previous block
304            let mut extended = self.overlap[ch].clone();
305            extended.extend_from_slice(&input);
306
307            for i in 0..frames {
308                let mut sum = 0.0f32;
309                for (j, &coeff) in self.coefficients.iter().enumerate() {
310                    sum += coeff * extended[i + order - 1 - j];
311                }
312                output_data[i * channels + ch] = sum;
313            }
314
315            // Save overlap for next block
316            let overlap_start = input.len().saturating_sub(order - 1);
317            self.overlap[ch] = input[overlap_start..].to_vec();
318            // Pad with zeros if input was shorter than order-1
319            while self.overlap[ch].len() < order - 1 {
320                self.overlap[ch].insert(0, 0.0);
321            }
322        }
323
324        Ok(Block {
325            data: output_data,
326            sample_rate: block.sample_rate,
327            channels: block.channels,
328        })
329    }
330}
331
/// IIR (Infinite Impulse Response) filter stage.
///
/// Implements the direct-form II transposed structure with configurable
/// feedforward (`b`) and feedback (`a`) coefficients, keeping one delay line
/// per channel between calls.
pub struct IirFilterStage {
    /// Feedforward coefficients, starting with `b[0]`.
    b: Vec<f32>,
    /// Feedback coefficients starting with `a[1]`; `a[0]` is taken to be
    /// 1.0 and is not stored.
    a: Vec<f32>,
    /// Per-channel delay line.
    delay: Vec<Vec<f32>>,
    /// Set once the per-channel delay lines have been allocated.
    initialized: bool,
}
346
347impl IirFilterStage {
348    /// Create a new IIR filter.
349    ///
350    /// `b` are the feedforward coefficients, `a` are the feedback coefficients.
351    /// `a[0]` is assumed to be 1.0 and should not be included in `a`.
352    pub fn new(b: Vec<f32>, a: Vec<f32>) -> Self {
353        Self {
354            b,
355            a,
356            delay: Vec::new(),
357            initialized: false,
358        }
359    }
360}
361
362impl DspStage for IirFilterStage {
363    fn process(&mut self, block: &Block) -> Result<Block, FabricError> {
364        let channels = block.channels as usize;
365        let frames = block.frames();
366        let order = core::cmp::max(self.b.len(), self.a.len() + 1);
367
368        if !self.initialized {
369            self.delay = (0..channels).map(|_| vec![0.0; order]).collect();
370            self.initialized = true;
371        }
372
373        let mut output_data = vec![0.0f32; block.data.len()];
374
375        for ch in 0..channels {
376            let delay = &mut self.delay[ch];
377
378            for i in 0..frames {
379                let x = block.data[i * channels + ch];
380
381                // Direct-form II transposed
382                let y = self.b.first().copied().unwrap_or(0.0) * x + delay[0];
383
384                // Update delay line
385                for j in 0..order - 1 {
386                    let b_term = if j + 1 < self.b.len() {
387                        self.b[j + 1] * x
388                    } else {
389                        0.0
390                    };
391                    let a_term = if j < self.a.len() { self.a[j] * y } else { 0.0 };
392                    let next = if j + 1 < delay.len() {
393                        delay[j + 1]
394                    } else {
395                        0.0
396                    };
397                    delay[j] = b_term - a_term + next;
398                }
399
400                output_data[i * channels + ch] = y;
401            }
402        }
403
404        Ok(Block {
405            data: output_data,
406            sample_rate: block.sample_rate,
407            channels: block.channels,
408        })
409    }
410}
411
/// Gain stage: scales every sample by one constant factor.
pub struct GainStage {
    /// Factor each sample is multiplied by.
    gain: f32,
}
416
417impl GainStage {
418    pub fn new(gain: f32) -> Self {
419        Self { gain }
420    }
421}
422
423impl DspStage for GainStage {
424    fn process(&mut self, block: &Block) -> Result<Block, FabricError> {
425        let data: Vec<Sample> = block.data.iter().map(|&s| s * self.gain).collect();
426        Ok(Block {
427            data,
428            sample_rate: block.sample_rate,
429            channels: block.channels,
430        })
431    }
432}
433
434/// Mixer stage: sums multiple input blocks into one.
435///
436/// Call `add_input` for each input block, then `mix` to produce the sum.
437pub struct MixerStage {
438    accumulated: Option<Block>,
439}
440
441impl MixerStage {
442    pub fn new() -> Self {
443        Self { accumulated: None }
444    }
445
446    /// Add an input block to the mix.
447    pub fn add_input(&mut self, block: &Block) -> Result<(), FabricError> {
448        match &mut self.accumulated {
449            None => {
450                self.accumulated = Some(block.clone());
451            }
452            Some(acc) => {
453                if acc.data.len() != block.data.len() {
454                    return Err(FabricError::CapacityExceeded);
455                }
456                for (a, &b) in acc.data.iter_mut().zip(block.data.iter()) {
457                    *a += b;
458                }
459            }
460        }
461        Ok(())
462    }
463
464    /// Produce the mixed output and reset internal state.
465    pub fn mix(&mut self) -> Result<Block, FabricError> {
466        self.accumulated.take().ok_or(FabricError::CapacityExceeded)
467    }
468}
469
470impl Default for MixerStage {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
476/// DspStage implementation for MixerStage that passes blocks through
477/// (accumulating for later mixing).
478impl DspStage for MixerStage {
479    fn process(&mut self, block: &Block) -> Result<Block, FabricError> {
480        self.add_input(block)?;
481        Ok(block.clone())
482    }
483}
484
/// Resample stage: converts blocks to a target sample rate by linear
/// interpolation.
pub struct ResampleStage {
    /// Sample rate of the blocks this stage produces.
    target_rate: u32,
}
489
490impl ResampleStage {
491    pub fn new(target_rate: u32) -> Self {
492        Self { target_rate }
493    }
494}
495
496impl DspStage for ResampleStage {
497    fn process(&mut self, block: &Block) -> Result<Block, FabricError> {
498        if block.sample_rate == self.target_rate {
499            return Ok(block.clone());
500        }
501
502        let channels = block.channels as usize;
503        let in_frames = block.frames();
504        if in_frames == 0 || channels == 0 {
505            return Ok(Block {
506                data: Vec::new(),
507                sample_rate: self.target_rate,
508                channels: block.channels,
509            });
510        }
511
512        let ratio = self.target_rate as f64 / block.sample_rate as f64;
513        let out_frames = (in_frames as f64 * ratio).round() as usize;
514        let mut output_data = vec![0.0f32; out_frames * channels];
515
516        for ch in 0..channels {
517            let input: Vec<f32> = (0..in_frames)
518                .map(|i| block.data[i * channels + ch])
519                .collect();
520
521            for i in 0..out_frames {
522                let src_pos = i as f64 / ratio;
523                let src_idx = src_pos as usize;
524                let frac = src_pos - src_idx as f64;
525
526                let s0 = input[src_idx.min(in_frames - 1)];
527                let s1 = input[(src_idx + 1).min(in_frames - 1)];
528                output_data[i * channels + ch] = s0 + (s1 - s0) * frac as f32;
529            }
530        }
531
532        Ok(Block {
533            data: output_data,
534            sample_rate: self.target_rate,
535            channels: block.channels,
536        })
537    }
538}
539
540#[cfg(test)]
541mod tests {
542    use super::*;
543    use alloc::vec;
544
545    fn approx_eq(a: f32, b: f32, tol: f32) -> bool {
546        (a - b).abs() < tol
547    }
548
549    fn mono_block(data: Vec<f32>, sample_rate: u32) -> Block {
550        Block {
551            data,
552            sample_rate,
553            channels: 1,
554        }
555    }
556
557    fn stereo_block(data: Vec<f32>, sample_rate: u32) -> Block {
558        Block {
559            data,
560            sample_rate,
561            channels: 2,
562        }
563    }
564
565    // --- FftStage ---
566
567    #[test]
568    fn fft_stage_power_of_two() {
569        let mut stage = FftStage::new();
570        let block = mono_block(vec![1.0, 0.0, -1.0, 0.0], 44100);
571        let result = stage.process(&block).unwrap();
572        // 4 input samples -> 8 output values (interleaved re/im)
573        assert_eq!(result.data.len(), 8);
574        assert_eq!(result.sample_rate, 44100);
575        assert_eq!(result.channels, 1);
576    }
577
578    #[test]
579    fn fft_stage_rejects_non_power_of_two() {
580        let mut stage = FftStage::new();
581        let block = mono_block(vec![1.0, 2.0, 3.0], 44100);
582        assert!(stage.process(&block).is_err());
583    }
584
585    #[test]
586    fn fft_stage_rejects_empty() {
587        let mut stage = FftStage::new();
588        let block = mono_block(vec![], 44100);
589        assert!(stage.process(&block).is_err());
590    }
591
592    #[test]
593    fn fft_stage_dc_signal() {
594        let mut stage = FftStage::new();
595        // Constant signal: all energy should be in bin 0
596        let block = mono_block(vec![1.0; 4], 44100);
597        let result = stage.process(&block).unwrap();
598        // Bin 0: re should be 4.0 (sum of all samples), im should be ~0
599        assert!(approx_eq(result.data[0], 4.0, 1e-4));
600        assert!(approx_eq(result.data[1], 0.0, 1e-4));
601    }
602
603    #[test]
604    fn fft_stage_default() {
605        let stage = FftStage::default();
606        let mut stage = stage;
607        let block = mono_block(vec![1.0, 0.0, 1.0, 0.0], 48000);
608        assert!(stage.process(&block).is_ok());
609    }
610
611    #[test]
612    fn fft_stage_stereo() {
613        let mut stage = FftStage::new();
614        // 2 channels, 4 frames = 8 interleaved samples
615        let block = stereo_block(vec![1.0, 2.0, 0.0, 0.0, -1.0, -2.0, 0.0, 0.0], 44100);
616        let result = stage.process(&block).unwrap();
617        // Each channel: 4 frames -> 4 complex bins -> 8 floats per channel = 16 total
618        assert_eq!(result.data.len(), 16);
619        assert_eq!(result.channels, 2);
620    }
621
622    // --- IfftStage ---
623
624    #[test]
625    fn ifft_stage_rejects_zero_channels() {
626        let mut stage = IfftStage::new();
627        let block = Block {
628            data: vec![1.0, 0.0, 1.0, 0.0],
629            sample_rate: 44100,
630            channels: 0,
631        };
632        assert!(stage.process(&block).is_err());
633    }
634
635    #[test]
636    fn ifft_stage_rejects_odd_count() {
637        let mut stage = IfftStage::new();
638        // 3 values for 1 channel is not a valid re/im pair count
639        let block = mono_block(vec![1.0, 0.0, 1.0], 44100);
640        assert!(stage.process(&block).is_err());
641    }
642
643    #[test]
644    fn ifft_stage_rejects_non_power_of_two_complex_count() {
645        let mut stage = IfftStage::new();
646        // 6 floats = 3 complex, but 3 is not a power of 2
647        let block = mono_block(vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0], 44100);
648        assert!(stage.process(&block).is_err());
649    }
650
651    #[test]
652    fn ifft_stage_default() {
653        let stage = IfftStage::default();
654        let mut stage = stage;
655        // 4 complex bins = 8 floats
656        let block = mono_block(vec![4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 44100);
657        let result = stage.process(&block).unwrap();
658        assert_eq!(result.data.len(), 4);
659    }
660
661    #[test]
662    fn fft_ifft_roundtrip_stage() {
663        let mut fft_stage = FftStage::new();
664        let mut ifft_stage = IfftStage::new();
665        let original = mono_block(vec![1.0, 2.0, 3.0, 4.0], 48000);
666        let freq = fft_stage.process(&original).unwrap();
667        let recovered = ifft_stage.process(&freq).unwrap();
668        assert_eq!(recovered.data.len(), 4);
669        for i in 0..4 {
670            assert!(
671                approx_eq(recovered.data[i], original.data[i], 1e-4),
672                "sample {i}: expected {}, got {}",
673                original.data[i],
674                recovered.data[i]
675            );
676        }
677    }
678
679    // --- FirFilterStage ---
680
681    #[test]
682    fn fir_identity_filter() {
683        // Single coefficient [1.0] = identity
684        let mut stage = FirFilterStage::new(vec![1.0]);
685        let block = mono_block(vec![1.0, 2.0, 3.0, 4.0], 44100);
686        let result = stage.process(&block).unwrap();
687        for i in 0..4 {
688            assert!(
689                approx_eq(result.data[i], block.data[i], 1e-6),
690                "sample {i}: expected {}, got {}",
691                block.data[i],
692                result.data[i]
693            );
694        }
695    }
696
697    #[test]
698    fn fir_gain_filter() {
699        // [2.0] = multiply by 2
700        let mut stage = FirFilterStage::new(vec![2.0]);
701        let block = mono_block(vec![1.0, 2.0, 3.0], 44100);
702        let result = stage.process(&block).unwrap();
703        assert!(approx_eq(result.data[0], 2.0, 1e-6));
704        assert!(approx_eq(result.data[1], 4.0, 1e-6));
705        assert!(approx_eq(result.data[2], 6.0, 1e-6));
706    }
707
708    #[test]
709    fn fir_moving_average() {
710        // 2-tap moving average: [0.5, 0.5]
711        let mut stage = FirFilterStage::new(vec![0.5, 0.5]);
712        let block = mono_block(vec![0.0, 1.0, 0.0, 1.0], 44100);
713        let result = stage.process(&block).unwrap();
714        // y[0] = 0.5*0 + 0.5*0 = 0 (overlap is zero)
715        assert!(approx_eq(result.data[0], 0.0, 1e-6));
716        // y[1] = 0.5*1 + 0.5*0 = 0.5
717        assert!(approx_eq(result.data[1], 0.5, 1e-6));
718        // y[2] = 0.5*0 + 0.5*1 = 0.5
719        assert!(approx_eq(result.data[2], 0.5, 1e-6));
720        // y[3] = 0.5*1 + 0.5*0 = 0.5
721        assert!(approx_eq(result.data[3], 0.5, 1e-6));
722    }
723
724    #[test]
725    fn fir_overlap_across_blocks() {
726        // Verify state carries across blocks
727        let mut stage = FirFilterStage::new(vec![0.5, 0.5]);
728        let block1 = mono_block(vec![1.0, 0.0], 44100);
729        let block2 = mono_block(vec![0.0, 1.0], 44100);
730
731        let _ = stage.process(&block1).unwrap();
732        let result2 = stage.process(&block2).unwrap();
733        // First sample of block2: y[0] = 0.5*0 + 0.5*0 = 0.0 (overlap from block1 tail)
734        // The overlap from block1 is [0.0] (last sample), so:
735        // y[0] = 0.5*0.0 + 0.5*0.0 = 0.0
736        assert!(approx_eq(result2.data[0], 0.0, 1e-6));
737    }
738
739    #[test]
740    fn fir_stereo() {
741        let mut stage = FirFilterStage::new(vec![1.0]);
742        // 2 channels, 2 frames: [L0, R0, L1, R1]
743        let block = stereo_block(vec![1.0, 2.0, 3.0, 4.0], 44100);
744        let result = stage.process(&block).unwrap();
745        assert_eq!(result.data.len(), 4);
746        assert!(approx_eq(result.data[0], 1.0, 1e-6)); // L0
747        assert!(approx_eq(result.data[1], 2.0, 1e-6)); // R0
748        assert!(approx_eq(result.data[2], 3.0, 1e-6)); // L1
749        assert!(approx_eq(result.data[3], 4.0, 1e-6)); // R1
750    }
751
752    // --- IirFilterStage ---
753
754    #[test]
755    fn iir_passthrough() {
756        // b=[1.0], a=[] => y = x (passthrough)
757        let mut stage = IirFilterStage::new(vec![1.0], vec![]);
758        let block = mono_block(vec![1.0, 2.0, 3.0, 4.0], 44100);
759        let result = stage.process(&block).unwrap();
760        for i in 0..4 {
761            assert!(
762                approx_eq(result.data[i], block.data[i], 1e-6),
763                "sample {i}: expected {}, got {}",
764                block.data[i],
765                result.data[i]
766            );
767        }
768    }
769
770    #[test]
771    fn iir_gain() {
772        // b=[2.0], a=[] => y = 2*x
773        let mut stage = IirFilterStage::new(vec![2.0], vec![]);
774        let block = mono_block(vec![1.0, 0.5, 0.25], 44100);
775        let result = stage.process(&block).unwrap();
776        assert!(approx_eq(result.data[0], 2.0, 1e-6));
777        assert!(approx_eq(result.data[1], 1.0, 1e-6));
778        assert!(approx_eq(result.data[2], 0.5, 1e-6));
779    }
780
781    #[test]
782    fn iir_first_order_feedback() {
783        // b=[1.0], a=[-0.5] => y[n] = x[n] + 0.5*y[n-1]
784        let mut stage = IirFilterStage::new(vec![1.0], vec![-0.5]);
785        let block = mono_block(vec![1.0, 0.0, 0.0, 0.0], 44100);
786        let result = stage.process(&block).unwrap();
787        // y[0] = 1.0 + 0 = 1.0
788        assert!(approx_eq(result.data[0], 1.0, 1e-6));
789        // y[1] = 0.0 + 0.5*1.0 = 0.5
790        assert!(approx_eq(result.data[1], 0.5, 1e-6));
791        // y[2] = 0.0 + 0.5*0.5 = 0.25
792        assert!(approx_eq(result.data[2], 0.25, 1e-6));
793        // y[3] = 0.0 + 0.5*0.25 = 0.125
794        assert!(approx_eq(result.data[3], 0.125, 1e-6));
795    }
796
797    #[test]
798    fn iir_stereo() {
799        // b=[1.0], a=[] passthrough on stereo
800        let mut stage = IirFilterStage::new(vec![1.0], vec![]);
801        let block = stereo_block(vec![1.0, 10.0, 2.0, 20.0], 44100);
802        let result = stage.process(&block).unwrap();
803        assert!(approx_eq(result.data[0], 1.0, 1e-6));
804        assert!(approx_eq(result.data[1], 10.0, 1e-6));
805        assert!(approx_eq(result.data[2], 2.0, 1e-6));
806        assert!(approx_eq(result.data[3], 20.0, 1e-6));
807    }
808
809    #[test]
810    fn iir_state_across_blocks() {
811        // b=[1.0], a=[-0.5] with impulse then zeros across two blocks
812        let mut stage = IirFilterStage::new(vec![1.0], vec![-0.5]);
813        let block1 = mono_block(vec![1.0, 0.0], 44100);
814        let result1 = stage.process(&block1).unwrap();
815        assert!(approx_eq(result1.data[0], 1.0, 1e-6));
816        assert!(approx_eq(result1.data[1], 0.5, 1e-6));
817
818        // Second block continues the decay
819        let block2 = mono_block(vec![0.0, 0.0], 44100);
820        let result2 = stage.process(&block2).unwrap();
821        assert!(approx_eq(result2.data[0], 0.25, 1e-6));
822        assert!(approx_eq(result2.data[1], 0.125, 1e-6));
823    }
824
825    // --- GainStage ---
826
827    #[test]
828    fn gain_stage_unity() {
829        let mut stage = GainStage::new(1.0);
830        let block = mono_block(vec![1.0, 2.0, 3.0], 44100);
831        let result = stage.process(&block).unwrap();
832        assert_eq!(result.data, block.data);
833    }
834
835    #[test]
836    fn gain_stage_amplify() {
837        let mut stage = GainStage::new(3.0);
838        let block = mono_block(vec![1.0, -1.0, 0.5], 44100);
839        let result = stage.process(&block).unwrap();
840        assert!(approx_eq(result.data[0], 3.0, 1e-6));
841        assert!(approx_eq(result.data[1], -3.0, 1e-6));
842        assert!(approx_eq(result.data[2], 1.5, 1e-6));
843    }
844
845    #[test]
846    fn gain_stage_attenuate() {
847        let mut stage = GainStage::new(0.5);
848        let block = mono_block(vec![4.0, 2.0], 44100);
849        let result = stage.process(&block).unwrap();
850        assert!(approx_eq(result.data[0], 2.0, 1e-6));
851        assert!(approx_eq(result.data[1], 1.0, 1e-6));
852    }
853
854    #[test]
855    fn gain_stage_zero() {
856        let mut stage = GainStage::new(0.0);
857        let block = mono_block(vec![100.0, -50.0], 44100);
858        let result = stage.process(&block).unwrap();
859        assert!(approx_eq(result.data[0], 0.0, 1e-6));
860        assert!(approx_eq(result.data[1], 0.0, 1e-6));
861    }
862
863    #[test]
864    fn gain_stage_negative() {
865        let mut stage = GainStage::new(-1.0);
866        let block = mono_block(vec![1.0, -2.0], 44100);
867        let result = stage.process(&block).unwrap();
868        assert!(approx_eq(result.data[0], -1.0, 1e-6));
869        assert!(approx_eq(result.data[1], 2.0, 1e-6));
870    }
871
872    #[test]
873    fn gain_stage_preserves_metadata() {
874        let mut stage = GainStage::new(2.0);
875        let block = Block {
876            data: vec![1.0],
877            sample_rate: 96000,
878            channels: 3,
879        };
880        let result = stage.process(&block).unwrap();
881        assert_eq!(result.sample_rate, 96000);
882        assert_eq!(result.channels, 3);
883    }
884
885    // --- MixerStage ---
886
887    #[test]
888    fn mixer_single_input() {
889        let mut mixer = MixerStage::new();
890        let block = mono_block(vec![1.0, 2.0, 3.0], 44100);
891        mixer.add_input(&block).unwrap();
892        let result = mixer.mix().unwrap();
893        assert_eq!(result.data, block.data);
894    }
895
896    #[test]
897    fn mixer_two_inputs() {
898        let mut mixer = MixerStage::new();
899        let a = mono_block(vec![1.0, 2.0, 3.0], 44100);
900        let b = mono_block(vec![10.0, 20.0, 30.0], 44100);
901        mixer.add_input(&a).unwrap();
902        mixer.add_input(&b).unwrap();
903        let result = mixer.mix().unwrap();
904        assert!(approx_eq(result.data[0], 11.0, 1e-6));
905        assert!(approx_eq(result.data[1], 22.0, 1e-6));
906        assert!(approx_eq(result.data[2], 33.0, 1e-6));
907    }
908
909    #[test]
910    fn mixer_mismatched_lengths() {
911        let mut mixer = MixerStage::new();
912        let a = mono_block(vec![1.0, 2.0], 44100);
913        let b = mono_block(vec![1.0, 2.0, 3.0], 44100);
914        mixer.add_input(&a).unwrap();
915        assert!(mixer.add_input(&b).is_err());
916    }
917
918    #[test]
919    fn mixer_empty_mix() {
920        let mut mixer = MixerStage::new();
921        assert!(mixer.mix().is_err());
922    }
923
924    #[test]
925    fn mixer_resets_after_mix() {
926        let mut mixer = MixerStage::new();
927        let block = mono_block(vec![1.0], 44100);
928        mixer.add_input(&block).unwrap();
929        let _ = mixer.mix().unwrap();
930        // After mix(), internal state is cleared
931        assert!(mixer.mix().is_err());
932    }
933
934    #[test]
935    fn mixer_default() {
936        let mut mixer = MixerStage::default();
937        let block = mono_block(vec![5.0], 44100);
938        mixer.add_input(&block).unwrap();
939        let result = mixer.mix().unwrap();
940        assert!(approx_eq(result.data[0], 5.0, 1e-6));
941    }
942
943    #[test]
944    fn mixer_dsp_stage_passes_through() {
945        let mut mixer = MixerStage::new();
946        let block = mono_block(vec![1.0, 2.0], 44100);
947        // DspStage::process accumulates and returns a clone
948        let result = mixer.process(&block).unwrap();
949        assert_eq!(result.data, block.data);
950        // The accumulated value is available via mix()
951        let mixed = mixer.mix().unwrap();
952        assert_eq!(mixed.data, block.data);
953    }
954
955    // --- ResampleStage ---
956
957    #[test]
958    fn resample_same_rate() {
959        let mut stage = ResampleStage::new(44100);
960        let block = mono_block(vec![1.0, 2.0, 3.0, 4.0], 44100);
961        let result = stage.process(&block).unwrap();
962        assert_eq!(result.data, block.data);
963        assert_eq!(result.sample_rate, 44100);
964    }
965
966    #[test]
967    fn resample_upsample_2x() {
968        let mut stage = ResampleStage::new(96000);
969        let block = mono_block(vec![0.0, 1.0], 48000);
970        let result = stage.process(&block).unwrap();
971        assert_eq!(result.sample_rate, 96000);
972        // 2 frames at 48k -> 4 frames at 96k
973        assert_eq!(result.data.len(), 4);
974        // First sample should be ~0.0
975        assert!(approx_eq(result.data[0], 0.0, 1e-4));
976        // Last sample should be ~1.0
977        assert!(approx_eq(result.data[3], 1.0, 1e-4));
978    }
979
980    #[test]
981    fn resample_downsample_2x() {
982        let mut stage = ResampleStage::new(22050);
983        let block = mono_block(vec![0.0, 0.25, 0.5, 0.75], 44100);
984        let result = stage.process(&block).unwrap();
985        assert_eq!(result.sample_rate, 22050);
986        // 4 frames at 44100 -> 2 frames at 22050
987        assert_eq!(result.data.len(), 2);
988    }
989
990    #[test]
991    fn resample_empty_block() {
992        let mut stage = ResampleStage::new(96000);
993        let block = mono_block(vec![], 44100);
994        let result = stage.process(&block).unwrap();
995        assert!(result.data.is_empty());
996        assert_eq!(result.sample_rate, 96000);
997    }
998
999    #[test]
1000    fn resample_stereo() {
1001        let mut stage = ResampleStage::new(96000);
1002        // 2 channels, 2 frames at 48k
1003        let block = stereo_block(vec![0.0, 10.0, 1.0, 20.0], 48000);
1004        let result = stage.process(&block).unwrap();
1005        assert_eq!(result.sample_rate, 96000);
1006        assert_eq!(result.channels, 2);
1007        // 2 frames -> 4 frames, 2 channels = 8 samples
1008        assert_eq!(result.data.len(), 8);
1009    }
1010
1011    #[test]
1012    fn resample_preserves_dc() {
1013        // A constant signal should remain constant after resampling
1014        let mut stage = ResampleStage::new(96000);
1015        let block = mono_block(vec![5.0; 4], 48000);
1016        let result = stage.process(&block).unwrap();
1017        for &sample in &result.data {
1018            assert!(approx_eq(sample, 5.0, 1e-4));
1019        }
1020    }
1021}