ESPHome 2025.5.0
Loading...
Searching...
No Matches
micro_wake_word.cpp
Go to the documentation of this file.
1#include "micro_wake_word.h"
2
3#ifdef USE_ESP_IDF
4
5#include "esphome/core/hal.h"
7#include "esphome/core/log.h"
8
10
11#ifdef USE_OTA
13#endif
14
15namespace esphome {
16namespace micro_wake_word {
17
18static const char *const TAG = "micro_wake_word";
19
20static const ssize_t DETECTION_QUEUE_LENGTH = 5;
21
22static const size_t DATA_TIMEOUT_MS = 50;
23
24static const uint32_t RING_BUFFER_DURATION_MS = 120;
25
26static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072;
27static const UBaseType_t INFERENCE_TASK_PRIORITY = 3;
28
// Flags exchanged through the FreeRTOS event group shared by the main loop and the inference task.
enum EventGroupBits : uint32_t {
  COMMAND_STOP = (1 << 0),  // Signals the inference task should stop

  // Lifecycle notifications set by the inference task; loop() logs and clears them.
  TASK_STARTING = (1 << 3),
  TASK_RUNNING = (1 << 4),
  TASK_STOPPING = (1 << 5),
  TASK_STOPPED = (1 << 6),

  // Failures reported by the inference task.
  ERROR_MEMORY = (1 << 9),
  ERROR_INFERENCE = (1 << 10),

  // Incoming audio did not fit in the ring buffer, so the buffer was reset; loop() logs a warning.
  WARNING_FULL_RING_BUFFER = (1 << 13),

  // Any error; the inference task skips each remaining setup stage while one of these is set.
  ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE,

  // Mask used to clear every flag above. FreeRTOS event groups expose 24 usable bits; this mask covers the
  // low 20, which includes every bit defined here.
  ALL_BITS = 0xfffff,
};
45
47
48static const LogString *micro_wake_word_state_to_string(State state) {
49 switch (state) {
50 case State::STARTING:
51 return LOG_STR("STARTING");
53 return LOG_STR("DETECTING_WAKE_WORD");
54 case State::STOPPING:
55 return LOG_STR("STOPPING");
56 case State::STOPPED:
57 return LOG_STR("STOPPED");
58 default:
59 return LOG_STR("UNKNOWN");
60 }
61}
62
64 ESP_LOGCONFIG(TAG, "microWakeWord:");
65 ESP_LOGCONFIG(TAG, " models:");
66 for (auto &model : this->wake_word_models_) {
67 model->log_model_config();
68 }
69#ifdef USE_MICRO_WAKE_WORD_VAD
70 this->vad_model_->log_model_config();
71#endif
72}
73
75 ESP_LOGCONFIG(TAG, "Setting up microWakeWord...");
76
77 this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
78 this->frontend_config_.window.step_size_ms = this->features_step_size_;
79 this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
80 this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT;
81 this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT;
82 this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS;
83 this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING;
84 this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING;
85 this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING;
86 this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN;
87 this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH;
88 this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET;
89 this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS;
90 this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG;
91 this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT;
92
93 this->event_group_ = xEventGroupCreate();
94 if (this->event_group_ == nullptr) {
95 ESP_LOGE(TAG, "Failed to create event group");
96 this->mark_failed();
97 return;
98 }
99
100 this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent));
101 if (this->detection_queue_ == nullptr) {
102 ESP_LOGE(TAG, "Failed to create detection event queue");
103 this->mark_failed();
104 return;
105 }
106
107 this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) {
108 if (this->state_ == State::STOPPED) {
109 return;
110 }
111 std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_.lock();
112 if (this->ring_buffer_.use_count() > 1) {
113 size_t bytes_free = temp_ring_buffer->free();
114
115 if (bytes_free < data.size()) {
117 temp_ring_buffer->reset();
118 }
119 temp_ring_buffer->write((void *) data.data(), data.size());
120 }
121 });
122
123#ifdef USE_OTA
125 [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
126 if (state == ota::OTA_STARTED) {
127 this->suspend_task_();
128 } else if (state == ota::OTA_ERROR) {
129 this->resume_task_();
130 }
131 });
132#endif
133 ESP_LOGCONFIG(TAG, "Micro Wake Word initialized");
134}
135
137 MicroWakeWord *this_mww = (MicroWakeWord *) params;
138
139 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING);
140
141 { // Ensures any C++ objects fall out of scope to deallocate before deleting the task
142
143 const size_t new_bytes_to_process =
145 std::unique_ptr<audio::AudioSourceTransferBuffer> audio_buffer;
146 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE];
147
148 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
149 // Allocate audio transfer buffer
150 audio_buffer = audio::AudioSourceTransferBuffer::create(new_bytes_to_process);
151
152 if (audio_buffer == nullptr) {
153 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
154 }
155 }
156
157 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
158 // Allocate ring buffer
159 std::shared_ptr<RingBuffer> temp_ring_buffer = RingBuffer::create(
160 this_mww->microphone_source_->get_audio_stream_info().ms_to_bytes(RING_BUFFER_DURATION_MS));
161 if (temp_ring_buffer.use_count() == 0) {
162 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
163 }
164 audio_buffer->set_source(temp_ring_buffer);
165 this_mww->ring_buffer_ = temp_ring_buffer;
166 }
167
168 if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
169 this_mww->microphone_source_->start();
170 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING);
171
172 while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) {
173 audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS));
174
175 if (audio_buffer->available() < new_bytes_to_process) {
176 // Insufficient data to generate new spectrogram features, read more next iteration
177 continue;
178 }
179
180 // Generate new spectrogram features
181 uint32_t processed_samples = this_mww->generate_features_(
182 (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer);
183 audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t));
184
185 // Run inference using the new spectorgram features
186 if (!this_mww->update_model_probabilities_(features_buffer)) {
187 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE);
188 break;
189 }
190
191 // Process each model's probabilities and possibly send a Detection Event to the queue
192 this_mww->process_probabilities_();
193 }
194 }
195 }
196
197 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING);
198
199 this_mww->unload_models_();
200 this_mww->microphone_source_->stop();
201 FrontendFreeStateContents(&this_mww->frontend_state_);
202
203 xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED);
204 while (true) {
205 // Continuously delay until the main loop deletes the task
206 delay(10);
207 }
208}
209
210std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() {
211 std::vector<WakeWordModel *> external_wake_word_models;
212 for (auto *model : this->wake_word_models_) {
213 if (!model->get_internal_only()) {
214 external_wake_word_models.push_back(model);
215 }
216 }
217 return external_wake_word_models;
218}
219
221
#ifdef USE_MICRO_WAKE_WORD_VAD
// Builds the VAD model from its parameters and stores it, replacing any previously added VAD model.
void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
                                  size_t tensor_arena_size) {
  this->vad_model_ =
      make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size);
}
#endif
228
230 if (this->inference_task_handle_ != nullptr) {
231 vTaskSuspend(this->inference_task_handle_);
232 }
233}
234
236 if (this->inference_task_handle_ != nullptr) {
237 vTaskResume(this->inference_task_handle_);
238 }
239}
240
242 uint32_t event_group_bits = xEventGroupGetBits(this->event_group_);
243
244 if (event_group_bits & EventGroupBits::ERROR_MEMORY) {
245 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY);
246 ESP_LOGE(TAG, "Encountered an error allocating buffers");
247 }
248
249 if (event_group_bits & EventGroupBits::ERROR_INFERENCE) {
250 xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE);
251 ESP_LOGE(TAG, "Encountered an error while performing an inference");
252 }
253
254 if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) {
255 xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
256 ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake "
257 "word detection accuracy will temporarily be reduced.");
258 }
259
260 if (event_group_bits & EventGroupBits::TASK_STARTING) {
261 ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers");
262 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING);
263 }
264
265 if (event_group_bits & EventGroupBits::TASK_RUNNING) {
266 ESP_LOGD(TAG, "Inference task is running");
267
268 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING);
270 }
271
272 if (event_group_bits & EventGroupBits::TASK_STOPPING) {
273 ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers");
274 xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING);
275 }
276
277 if ((event_group_bits & EventGroupBits::TASK_STOPPED)) {
278 ESP_LOGD(TAG, "Inference task is finished, freeing task resources");
279 vTaskDelete(this->inference_task_handle_);
280 this->inference_task_handle_ = nullptr;
281 xEventGroupClearBits(this->event_group_, ALL_BITS);
282 xQueueReset(this->detection_queue_);
284 }
285
286 if ((this->pending_start_) && (this->state_ == State::STOPPED)) {
288 this->pending_start_ = false;
289 }
290
291 if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) {
293 this->pending_stop_ = false;
294 }
295
296 switch (this->state_) {
297 case State::STARTING:
298 if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) {
299 // Setup preprocesor feature generator. If done in the task, it would lock the task to its initial core, as it
300 // uses floating point operations.
301 if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_,
304 "Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000);
305 return;
306 }
307
308 xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this,
309 INFERENCE_TASK_PRIORITY, &this->inference_task_handle_);
310
311 if (this->inference_task_handle_ == nullptr) {
312 FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state
313 this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000);
314 }
315 }
316 break;
318 DetectionEvent detection_event;
319 while (xQueueReceive(this->detection_queue_, &detection_event, 0)) {
320 if (detection_event.blocked_by_vad) {
321 ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str());
322 } else {
323 constexpr float uint8_to_float_divisor =
324 255.0f; // Converting a quantized uint8 probability to floating point
325 ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f",
326 detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor),
327 (detection_event.max_probability / uint8_to_float_divisor));
328 this->wake_word_detected_trigger_->trigger(*detection_event.wake_word);
329 if (this->stop_after_detection_) {
330 this->stop();
331 }
332 }
333 }
334 break;
335 }
336 case State::STOPPING:
337 xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP);
338 break;
339 case State::STOPPED:
340 break;
341 }
342}
343
345 if (!this->is_ready()) {
346 ESP_LOGW(TAG, "Wake word detection can't start as the component hasn't been setup yet");
347 return;
348 }
349
350 if (this->is_failed()) {
351 ESP_LOGW(TAG, "Wake word component is marked as failed. Please check setup logs");
352 return;
353 }
354
355 if (this->is_running()) {
356 ESP_LOGW(TAG, "Wake word detection is already running");
357 return;
358 }
359
360 ESP_LOGD(TAG, "Starting wake word detection");
361
362 this->pending_start_ = true;
363 this->pending_stop_ = false;
364}
365
367 if (this->state_ == STOPPED)
368 return;
369
370 ESP_LOGD(TAG, "Stopping wake word detection");
371
372 this->pending_start_ = false;
373 this->pending_stop_ = true;
374}
375
377 if (this->state_ != state) {
378 ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
379 LOG_STR_ARG(micro_wake_word_state_to_string(state)));
380 this->state_ = state;
381 }
382}
383
384size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available,
385 int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) {
386 size_t processed_samples = 0;
387 struct FrontendOutput frontend_output =
388 FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples);
389
390 for (size_t i = 0; i < frontend_output.size; ++i) {
391 // These scaling values are set to match the TFLite audio frontend int8 output.
392 // The feature pipeline outputs 16-bit signed integers in roughly a 0 to 670
393 // range. In training, these are then arbitrarily divided by 25.6 to get
394 // float values in the rough range of 0.0 to 26.0. This scaling is performed
395 // for historical reasons, to match up with the output of other feature
396 // generators.
397 // The process is then further complicated when we quantize the model. This
398 // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN)
399 // to 127 (INT8_MAX) signed integer numbers.
400 // All this means that to get matching values from our integer feature
401 // output into the tensor input, we have to perform:
402 // input = (((feature / 25.6) / 26.0) * 256) - 128
403 // To simplify this and perform it in 32-bit integer math, we rearrange to:
404 // input = (feature * 256) / (25.6 * 26.0) - 128
405 constexpr int32_t value_scale = 256;
406 constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding
407 int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div;
408
409 value += INT8_MIN; // Adds a -128; i.e., subtracts 128
410 features_buffer[i] = static_cast<int8_t>(clamp<int32_t>(value, INT8_MIN, INT8_MAX));
411 }
412
413 return processed_samples;
414}
415
417#ifdef USE_MICRO_WAKE_WORD_VAD
418 DetectionEvent vad_state = this->vad_model_->determine_detected();
419
420 this->vad_state_ = vad_state.detected; // atomic write, so thread safe
421#endif
422
423 for (auto &model : this->wake_word_models_) {
424 if (model->get_unprocessed_probability_status()) {
425 // Only detect wake words if there is a new probability since the last check
426 DetectionEvent wake_word_state = model->determine_detected();
427 if (wake_word_state.detected) {
428#ifdef USE_MICRO_WAKE_WORD_VAD
429 if (vad_state.detected) {
430#endif
431 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
432 model->reset_probabilities();
433#ifdef USE_MICRO_WAKE_WORD_VAD
434 } else {
435 wake_word_state.blocked_by_vad = true;
436 xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
437 }
438#endif
439 }
440 }
441 }
442}
443
445 for (auto &model : this->wake_word_models_) {
446 model->unload_model();
447 }
448#ifdef USE_MICRO_WAKE_WORD_VAD
449 this->vad_model_->unload_model();
450#endif
451}
452
453bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) {
454 bool success = true;
455
456 for (auto &model : this->wake_word_models_) {
457 // Perform inference
458 success = success & model->perform_streaming_inference(audio_features);
459 }
460#ifdef USE_MICRO_WAKE_WORD_VAD
461 success = success & this->vad_model_->perform_streaming_inference(audio_features);
462#endif
463
464 return success;
465}
466
467} // namespace micro_wake_word
468} // namespace esphome
469
470#endif // USE_ESP_IDF
virtual void mark_failed()
Mark this component as failed.
bool is_failed() const
void status_momentary_error(const std::string &name, uint32_t length=5000)
bool is_ready() const
bool status_has_error() const
static std::unique_ptr< RingBuffer > create(size_t len)
void trigger(Ts... x)
Inform the parent automation that the event has triggered.
Definition automation.h:96
static std::unique_ptr< AudioSourceTransferBuffer > create(size_t buffer_size)
Creates a new source transfer buffer.
size_t ms_to_bytes(uint32_t ms) const
Converts duration to bytes.
Definition audio.h:73
uint32_t get_sample_rate() const
Definition audio.h:30
void resume_task_()
Resumes the inference task.
microphone::MicrophoneSource * microphone_source_
void process_probabilities_()
Processes any new probabilities for each model.
std::weak_ptr< RingBuffer > ring_buffer_
Trigger< std::string > * wake_word_detected_trigger_
std::vector< WakeWordModel * > wake_word_models_
void suspend_task_()
Suspends the inference task.
void add_wake_word_model(WakeWordModel *model)
bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE])
Runs an inference with each model using the new spectrogram features.
size_t generate_features_(int16_t *audio_buffer, size_t samples_available, int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE])
Generates spectrogram features from an input buffer of audio samples.
std::unique_ptr< VADModel > vad_model_
void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void unload_models_()
Deletes each model's TFLite interpreters and frees tensor arena memory.
std::vector< WakeWordModel * > get_wake_words()
void add_data_callback(std::function< void(const std::vector< uint8_t > &)> &&data_callback)
audio::AudioStreamInfo get_audio_stream_info()
Gets the AudioStreamInfo of the data after processing.
void add_on_state_callback(std::function< void(OTAState, float, uint8_t, OTAComponent *)> &&callback)
Definition ota_backend.h:82
bool state
Definition fan.h:0
__int64 ssize_t
Definition httplib.h:175
OTAGlobalCallback * get_global_ota_callback()
const float AFTER_CONNECTION
For components that should be initialized after a data connection (API/MQTT) is connected.
Definition component.cpp:27
Providing packet encoding functions for exchanging data with a remote host.
Definition a01nyub.cpp:7
std::unique_ptr< T > make_unique(Args &&...args)
Definition helpers.h:85
void IRAM_ATTR HOT delay(uint32_t ms)
Definition core.cpp:28
constexpr const T & clamp(const T &v, const T &lo, const T &hi, Compare comp)
Definition helpers.h:101