Source code

Revision control

Copy as Markdown

Other Tools

/*
* Copyright (c) 2025 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
/*
Python script used to generate the test data:
import numpy as np
from typing import List
def python_feature_extractor(time_frame: np.ndarray) -> np.ndarray:
    frame_length: int = 256
    sqrt_hann: np.ndarray = np.sqrt(np.hanning(frame_length))
    magnitude_spectrum: np.ndarray = np.abs(np.fft.rfft(time_frame * sqrt_hann))
    return np.power(magnitude_spectrum + 1e-8, 0.3)
def format_as_cpp_array(data: np.ndarray, name: str) -> str:
    elements_per_line = 6
    s = f"constexpr float {name}[] = {{\n "
    for i, x in enumerate(data):
        s += f"{x:.8f}, "
        if (i + 1) % elements_per_line == 0 and i < len(data) - 1:
            s += "\n "
    s = s.rstrip(", ") + "\n};"
    return s
# Generate two frames of white noise
np.random.seed(0) # for reproducibility
noise1: np.ndarray = np.random.uniform(-1.0, 1.0, 256)
noise2: np.ndarray = np.random.uniform(-1.0, 1.0, 256)
# Scale to match the C++ implementation's expected input range
noise1_scaled: np.ndarray = noise1 * 32768.0
noise2_scaled: np.ndarray = noise2 * 32768.0
# Python equivalent
expected_output1: np.ndarray = python_feature_extractor(noise1)
expected_output2: np.ndarray = python_feature_extractor(noise2)
print(format_as_cpp_array(noise1_scaled, "noise1_scaled"))
print(format_as_cpp_array(noise2_scaled, "noise2_scaled"))
print(format_as_cpp_array(expected_output1, "expected_output1"))
print(format_as_cpp_array(expected_output2, "expected_output2"))
*/
#include "modules/audio_processing/aec3/neural_feature_extractor.h"
#include <vector>
#include "test/gmock.h"
#include "test/gtest.h"
namespace webrtc {
namespace {
using ::testing::FloatNear;
using ::testing::Pointwise;
// Tolerance passed to FloatNear() when comparing the extractor output with the
// reference values below.
constexpr float kTolerance = 1e-6f;
// Test data generated from the Python implementation.
// First input frame (256 samples): `noise1_scaled` from the Python script
// above, i.e. uniform noise drawn with np.random.uniform(-1.0, 1.0, 256) and
// multiplied by 32768.
constexpr float noise1_scaled[] = {
3199.0418, 14102.6503, 6734.7006, 2941.4643, -5003.3591,
9561.3166, -4090.2845, 25675.2354, 30386.6027, -7638.7766,
19118.4921, 1893.6575, 4459.3684, 27891.9013, -28112.5809,
-27057.8942, -31442.9671, 21798.5742, 18229.2808, 24249.1161,
31366.7317, 19605.6557, -2524.4885, 18384.7601, -25016.7672,
9169.8641, -23373.1990, 29141.8221, 1431.8516, -5592.7151,
-15430.0834, 17972.1791, -2873.7318, 4484.8873, -31536.5916,
7709.3599, 7346.3053, 7663.3864, 29081.4741, 11915.7751,
-9207.2902, -4126.6739, 12951.9581, -28821.0635, 10929.2235,
11182.9234, -18980.3685, -24318.6862, -12096.0876, -8931.8509,
4600.4155, -4023.8112, 32006.0679, -26080.3913, -19079.0529,
-22196.4194, 10034.1072, -16168.2815, -2207.8572, -16749.3244,
-22349.7694, -25534.4547, 10245.2160, -23712.0421, -19884.7783,
-8603.2272, 21036.6123, -26404.3708, 22147.5575, -26470.0947,
31225.2475, -2054.4748, 31245.0147, 6871.1560, 15680.3779,
-30199.7888, -14233.9629, -24890.7982, -13360.1560, -24987.0602,
-11928.6544, -5618.8604, -28564.0297, 12613.8528, 4364.7929,
-15375.4343, 1523.5844, -26611.5147, 4977.2295, 28134.3556,
-11890.2651, 10971.4067, -24130.4953, 14177.2196, -13801.4823,
-20762.3709, 5669.7117, -31450.2319, 21557.4138, -32460.2773,
11653.3846, -15072.7575, 15413.6754, 30289.9885, -16465.7140,
4991.0471, 6032.0600, 4735.1009, -18148.1221, 29671.3592,
-3465.1912, 22702.2388, 13073.0738, -13275.1720, 20565.0539,
-6782.5998, 24975.9791, 5326.2990, 25017.4087, 12617.7503,
14762.2645, 86.7947, 29889.8971, 9436.5417, -4990.2355,
6972.5857, -31510.1546, -13003.9928, 10497.1330, -13757.4739,
7734.2592, -4668.2144, -23889.5717, -13219.7695, 4585.2204,
5955.4373, 4870.9795, 10040.1689, 9968.2399, -4494.5614,
25988.0777, -8679.4653, -4203.1563, 25685.0890, 20066.7293,
13362.0422, -26199.5307, 27491.2126, 14040.5178, 32692.4374,
-22973.7559, 24125.5093, -22118.8630, 7573.3116, -24653.3336,
22807.0673, 20140.4553, 4528.5860, -6082.8354, -28235.0718,
12938.6921, -3044.6267, 14552.6358, 24011.2321, 31163.7774,
23317.9278, -32000.3058, -9176.4776, 15072.6615, -21520.0775,
1378.6550, -29206.9056, -19661.0277, -31554.1557, 19247.7727,
-18092.8716, -10135.0323, 28054.7356, 13396.5022, -30681.4039,
-21974.6038, 7961.2085, 5061.2528, -17177.4561, 28456.6486,
7468.8729, 2335.2314, 5892.3402, 15081.2773, -12324.3728,
-6670.1845, -19015.6801, -20565.6552, 29122.3889, 15699.2009,
-625.2915, -17864.1549, -16098.4936, -28965.0009, -4298.0720,
-12334.1451, 12867.5669, -8011.6555, -20997.4934, -31150.6549,
-28360.7282, 11756.6848, -3034.5236, 2397.2552, 25996.2499,
32134.8533, -18553.4392, 10687.4931, -15510.9047, -31414.6161,
16933.1035, -11795.3560, -7637.3102, 5787.9504, 21695.5916,
8452.9541, 24422.0334, -14841.1492, 19532.7973, -20602.1628,
29674.1540, 12287.2317, -18644.4889, 29318.8790, 15129.3662,
-16125.6805, -18788.3863, 1192.8020, -31086.1681, -19171.2411,
-4935.8131, -8246.3962, -2387.1210, -14573.3251, 5687.4989,
23845.6410, -25065.4323, 1138.9572, -24112.7846, 14212.1161,
-6811.8313, 4287.4511, -20756.5727, -23275.2572, -782.7436,
-9462.5636, 28864.1480, 17388.3558, 16296.4190, 26458.1769,
-27300.8273};
// Second input frame (256 samples): `noise2_scaled` from the Python script
// above, i.e. the second np.random.uniform(-1.0, 1.0, 256) draw multiplied by
// 32768.
constexpr float noise2_scaled[] = {
3420.4857, 5536.2237, 30273.4625, -13621.8197, -16985.0451,
-26195.1362, -31691.2678, 28149.6333, 11135.6508, 18687.7812,
-14304.5358, 5662.9767, -28576.6277, -941.9099, 31293.1215,
24674.6478, -10606.4149, 30249.4616, -17583.2022, 29446.5583,
28926.1293, 19608.5408, 8549.0360, 24529.3362, -13564.6226,
22868.3648, 7725.1669, -31900.5093, -10011.7042, -23059.4405,
31577.1709, -1417.5236, -170.9595, 9140.4708, -8612.4393,
-23796.1038, 21110.3078, -20326.1272, 741.8008, -18067.1592,
-26355.6639, 23736.5833, 30993.2516, 30201.2602, 26644.0212,
17959.9660, -10934.9993, -27452.9393, -6079.0426, -17548.3033,
-24085.2904, -29266.5962, 14784.5523, -32019.0901, 17732.7799,
-23137.7046, -27556.4408, -26895.7755, 11275.3251, -16687.6145,
-5207.5255, 3759.7211, 23629.0817, 14879.5728, -15051.7904,
-24151.1433, -29138.9885, -13002.4319, -15589.8250, -2874.3718,
12011.5256, 12820.5092, -14187.3089, -7869.1070, -20896.0906,
18910.1187, -29042.4045, 12910.4112, 18264.5815, 18180.1820,
-15766.4828, -8269.7822, 5740.9297, -14888.3438, -8463.7910,
-19853.8507, -2630.8848, -29844.2882, 19647.4231, -27724.5823,
1234.3803, -12660.8933, 5081.8547, 30109.4234, 9540.0915,
-30450.4874, -4561.1457, 656.4644, 2370.9283, 11887.7396,
-14575.4621, -24322.9940, -7033.6069, 29911.0054, -20504.1899,
26475.4925, 2870.8667, -2823.8531, 25037.4659, -2712.9308,
14691.0502, -6617.4765, 26479.4533, 12453.4797, 13082.4309,
-11290.5158, 16828.2451, 8916.8973, -17038.0314, -22246.9277,
19424.3117, 30091.9425, -2743.4138, 5962.7383, 23443.7112,
-2803.4038, 29614.0457, 4964.4282, 21021.7940, 26793.9819,
20678.1690, -22320.6137, 8447.4881, -6656.2124, -28658.0440,
-4978.6223, -15814.8810, 22874.5746, -30585.3480, 30079.8917,
-9478.5471, -9390.8572, -31697.8952, -20628.6143, -6471.0574,
28134.0423, -26239.6359, 29183.2813, 24214.8003, -3004.0132,
-11357.3310, -17514.8807, 7501.5590, -30600.4236, -31745.2410,
-4666.4435, -28306.6975, -16256.7954, -18273.9983, -16174.8619,
-24179.1644, -31979.1941, -25199.6211, 7764.7223, 31080.8552,
32135.2500, -5960.2308, -22088.6187, 9093.8905, -635.3488,
32073.9592, -28488.2235, 18562.0521, -13867.5161, -16946.3893,
10649.8996, -16642.0031, 10869.7431, 1134.3310, -4974.9041,
3584.0202, -13955.7916, 13538.0800, -5579.9402, -9139.2861,
21538.8596, 27850.6315, -29752.8649, -17522.5574, -9927.4346,
20641.6432, 31817.1662, 30734.5296, 26538.6948, -13332.8886,
32244.4488, -16422.0082, -25827.3342, 29553.6303, -17470.5701,
12436.6530, -28943.5577, 15119.7515, 25016.4158, -14913.5756,
-7926.1273, -8238.1253, 16304.5872, -17183.0646, -21505.4353,
-3323.2225, -12814.3585, 22229.0983, -17187.3517, 156.5955,
29005.1588, 8781.6731, 24070.6785, 28849.5822, 16434.1260,
13079.3511, 30668.5914, 32401.0502, -3157.4142, -28123.4782,
-13579.4504, -22783.2820, -5407.6129, -24163.8226, 6823.4644,
-7680.2910, 25912.0093, 30657.3916, 3072.6489, -14757.1625,
6044.4127, 26002.1393, -6112.3234, 3413.0019, -14964.9642,
-2920.0122, -6441.3017, -16487.9752, 384.4593, -12426.8822,
-8320.7872, 1636.4629, 16422.9954, -10911.2547, 27797.6689,
23744.9083, -29577.0328, -16145.2835, -3530.0630, -25911.1067,
-9930.2776};
// Reference features for the first frame (129 values): `expected_output1` from
// the Python script above, i.e. python_feature_extractor() applied to the
// unscaled `noise1`.
constexpr float expected_output1[] = {
1.2202633, 1.7272304, 2.1753535, 1.6639511, 1.4504087, 1.5607018, 1.7948899,
2.0986237, 1.5073231, 0.8968942, 1.8699082, 2.0595039, 1.6333632, 1.6918161,
1.7579499, 1.8817867, 1.9033698, 1.9425666, 1.4443926, 1.1040287, 1.4225169,
1.1535512, 1.4586320, 1.6414003, 1.7915095, 1.6791658, 1.4245994, 1.5361380,
2.0224556, 1.4529938, 1.3677414, 1.6004674, 1.7868200, 1.3359326, 1.9621301,
1.4692749, 1.6836248, 1.6219408, 1.6542681, 2.2417320, 1.9120614, 1.6369832,
1.1818825, 1.3819567, 1.3740384, 1.3323426, 1.7350840, 1.5579263, 1.1322017,
1.2045572, 1.6530098, 1.6843505, 1.2226551, 1.4986210, 1.5158652, 1.5288519,
1.4476088, 1.7631098, 1.4404006, 1.0171719, 1.7696546, 2.0226616, 2.0523162,
1.5416487, 1.5385250, 1.0534991, 1.3605192, 1.4166694, 1.6238999, 1.6638377,
1.4028377, 1.6349643, 1.5471496, 1.5039228, 1.3435868, 2.0315477, 1.6629901,
1.5412650, 1.7623193, 1.8761405, 1.5532731, 1.9655503, 1.9347810, 1.4526643,
0.9392141, 1.5384618, 1.5229951, 1.3041083, 1.2288715, 1.5890568, 1.4367742,
1.8774723, 1.7158524, 1.5562983, 1.8137322, 1.3629094, 1.5521119, 1.5687853,
1.6626421, 1.8479395, 1.6954730, 1.5309387, 1.5702729, 1.8073848, 1.8335479,
1.7042632, 1.6445036, 1.6976509, 1.1417790, 1.5238974, 2.1945088, 2.1619850,
1.7098370, 1.6523124, 1.9440371, 1.8486495, 2.0672979, 2.1444454, 1.3407502,
1.6834193, 1.8205304, 1.8301605, 1.6322767, 1.9723609, 2.1829500, 1.4344222,
1.9528573, 1.8263685, 1.7367752};
// Reference features for the second frame (129 values): `expected_output2`
// from the Python script above, i.e. python_feature_extractor() applied to the
// unscaled `noise2`.
constexpr float expected_output2[] = {
1.4387245, 1.4403429, 1.9777731, 2.2877072, 1.8727017, 1.9210667, 1.3046225,
1.4128947, 1.1877113, 1.4548492, 1.9006205, 1.8816212, 2.1369724, 1.5365391,
2.0703334, 2.1232226, 1.6342221, 1.3652948, 1.7604186, 1.9382242, 1.3772832,
1.4379133, 1.7277498, 1.4895349, 1.7578079, 1.3160488, 1.4388874, 2.1002045,
1.6475185, 2.0833502, 2.0610020, 1.5717945, 1.2313899, 1.8931674, 1.4281442,
1.6715665, 1.9357384, 1.0737266, 1.8492249, 1.6154146, 1.7171611, 1.2594775,
1.6492247, 1.7053658, 1.7071571, 1.4533884, 1.8833118, 1.6621069, 1.7046046,
1.3899836, 1.9614496, 1.5187381, 1.5449415, 1.7357930, 1.7269029, 1.4725356,
1.8128068, 1.3997501, 1.7827980, 1.6610097, 1.8852335, 1.2815337, 2.0144823,
1.5584240, 1.3150680, 2.0394259, 2.0875142, 1.7976422, 1.4628408, 1.6218163,
1.5199225, 1.2747693, 0.7672618, 1.6249810, 1.8661000, 1.7897703, 1.3933436,
1.6946455, 1.6239630, 1.5750381, 1.6310876, 1.1158317, 1.7192902, 1.5997313,
1.9059273, 1.8387797, 1.8939278, 1.5037787, 1.3948746, 1.6324782, 1.2768368,
1.5373271, 1.4816556, 1.3888880, 1.5884829, 1.3518394, 1.8841741, 1.4184990,
1.3060136, 1.8983842, 1.9823723, 1.9028293, 1.9833195, 2.0924404, 1.3062575,
1.4686721, 1.5963211, 1.4120681, 1.5727397, 1.3416973, 2.0295401, 1.7874475,
1.3433882, 1.7285510, 1.9679515, 2.0707020, 1.9011226, 1.6455043, 1.5474962,
1.8804304, 2.0154081, 2.0223212, 1.8945941, 1.4271733, 2.0051481, 1.7717120,
1.6581700, 1.2560840, 1.4274090};
// Feeds two 256-sample noise frames through FrequencyDomainFeatureExtractor,
// each as two consecutive kStepSize-sample pushes, and checks the output
// written by the second push of every frame against the Python-generated
// reference values.
TEST(NeuralFeatureExtractorTest, FrequencyDomainFeatureExtractor) {
  constexpr int kStepSize = 128;

  // The pushes below read 2 * kStepSize samples from each input array and the
  // comparisons expect kStepSize + 1 reference values. Pin those sizes at
  // compile time so regenerated test data of the wrong length cannot cause an
  // out-of-bounds read.
  static_assert(static_cast<int>(std::size(noise1_scaled)) == 2 * kStepSize,
                "noise1_scaled must hold exactly two steps of samples");
  static_assert(static_cast<int>(std::size(noise2_scaled)) == 2 * kStepSize,
                "noise2_scaled must hold exactly two steps of samples");
  static_assert(static_cast<int>(std::size(expected_output1)) == kStepSize + 1,
                "expected_output1 must hold kStepSize + 1 values");
  static_assert(static_cast<int>(std::size(expected_output2)) == kStepSize + 1,
                "expected_output2 must hold kStepSize + 1 values");

  FrequencyDomainFeatureExtractor extractor(kStepSize);

  // Input and output buffers reused across all pushes.
  std::vector<float> input(kStepSize);
  std::vector<float> output(kStepSize + 1);

  // Pushes the two kStepSize-sample halves of `frame` into the extractor, in
  // order. `output` is only inspected by the caller after the second push.
  const auto push_frame = [&](const float* frame) {
    for (int step = 0; step < 2; ++step) {
      const float* const first = frame + step * kStepSize;
      input.assign(first, first + kStepSize);
      extractor.PushFeaturesToModelInput(input, output);
    }
  };

  // First frame.
  push_frame(noise1_scaled);
  EXPECT_THAT(output, Pointwise(FloatNear(kTolerance), expected_output1));

  // Second frame.
  push_frame(noise2_scaled);
  EXPECT_THAT(output, Pointwise(FloatNear(kTolerance), expected_output2));
}
} // namespace
} // namespace webrtc