ESPHome 2026.3.0
Loading...
Searching...
No Matches
streaming_model.cpp
Go to the documentation of this file.
1#include "streaming_model.h"
2
3#ifdef USE_ESP32
4
6#include "esphome/core/log.h"
7
8static const char *const TAG = "micro_wake_word";
9
10namespace esphome {
11namespace micro_wake_word {
12
14 ESP_LOGCONFIG(TAG,
15 " - Wake Word: %s\n"
16 " Probability cutoff: %.2f\n"
17 " Sliding window size: %d",
18 this->wake_word_.c_str(), this->probability_cutoff_ / 255.0f, this->sliding_window_size_);
19}
20
22 ESP_LOGCONFIG(TAG,
23 " - VAD Model\n"
24 " Probability cutoff: %.2f\n"
25 " Sliding window size: %d",
26 this->probability_cutoff_ / 255.0f, this->sliding_window_size_);
27}
28
30 RAMAllocator<uint8_t> arena_allocator;
31
32 if (this->tensor_arena_ == nullptr) {
33 this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_);
34 if (this->tensor_arena_ == nullptr) {
35 ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena.");
36 return false;
37 }
38 }
39
40 if (this->var_arena_ == nullptr) {
41 this->var_arena_ = arena_allocator.allocate(STREAMING_MODEL_VARIABLE_ARENA_SIZE);
42 if (this->var_arena_ == nullptr) {
43 ESP_LOGE(TAG, "Could not allocate the streaming model's variable tensor arena.");
44 return false;
45 }
46 this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
47 this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20);
48 }
49
50 const tflite::Model *model = tflite::GetModel(this->model_start_);
51 if (model->version() != TFLITE_SCHEMA_VERSION) {
52 ESP_LOGE(TAG, "Streaming model's schema is not supported");
53 return false;
54 }
55
56 if (this->interpreter_ == nullptr) {
57 this->interpreter_ =
58 make_unique<tflite::MicroInterpreter>(tflite::GetModel(this->model_start_), this->streaming_op_resolver_,
59 this->tensor_arena_, this->tensor_arena_size_, this->mrv_);
60 if (this->interpreter_->AllocateTensors() != kTfLiteOk) {
61 ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
62 return false;
63 }
64
65 // Verify input tensor matches expected values
66 // Dimension 3 will represent the first layer stride, so skip it may vary
67 TfLiteTensor *input = this->interpreter_->input(0);
68 if ((input->dims->size != 3) || (input->dims->data[0] != 1) ||
69 (input->dims->data[2] != PREPROCESSOR_FEATURE_SIZE)) {
70 ESP_LOGE(TAG, "Streaming model tensor input dimensions has improper dimensions.");
71 return false;
72 }
73
74 if (input->type != kTfLiteInt8) {
75 ESP_LOGE(TAG, "Streaming model tensor input is not int8.");
76 return false;
77 }
78
79 // Verify output tensor matches expected values
80 TfLiteTensor *output = this->interpreter_->output(0);
81 if ((output->dims->size != 2) || (output->dims->data[0] != 1) || (output->dims->data[1] != 1)) {
82 ESP_LOGE(TAG, "Streaming model tensor output dimension is not 1x1.");
83 return false;
84 }
85
86 if (output->type != kTfLiteUInt8) {
87 ESP_LOGE(TAG, "Streaming model tensor output is not uint8.");
88 return false;
89 }
90 }
91
92 this->loaded_ = true;
93 this->reset_probabilities();
94 return true;
95}
96
98 this->interpreter_.reset();
99
100 RAMAllocator<uint8_t> arena_allocator;
101
102 if (this->tensor_arena_ != nullptr) {
103 arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
104 this->tensor_arena_ = nullptr;
105 }
106
107 if (this->var_arena_ != nullptr) {
108 arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
109 this->var_arena_ = nullptr;
110 }
111
112 this->loaded_ = false;
113}
114
115bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
116 if (this->enabled_ && !this->loaded_) {
117 // Model is enabled but isn't loaded
118 if (!this->load_model_()) {
119 return false;
120 }
121 }
122
123 if (!this->enabled_ && this->loaded_) {
124 // Model is disabled but still loaded
125 this->unload_model();
126 return true;
127 }
128
129 if (this->loaded_) {
130 TfLiteTensor *input = this->interpreter_->input(0);
131
132 uint8_t stride = this->interpreter_->input(0)->dims->data[1];
133 this->current_stride_step_ = this->current_stride_step_ % stride;
134
135 std::memmove(
136 (int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_,
137 features, PREPROCESSOR_FEATURE_SIZE);
138 ++this->current_stride_step_;
139
140 if (this->current_stride_step_ >= stride) {
141 TfLiteStatus invoke_status = this->interpreter_->Invoke();
142 if (invoke_status != kTfLiteOk) {
143 ESP_LOGW(TAG, "Streaming interpreter invoke failed");
144 return false;
145 }
146
147 TfLiteTensor *output = this->interpreter_->output(0);
148
149 ++this->last_n_index_;
150 if (this->last_n_index_ == this->sliding_window_size_)
151 this->last_n_index_ = 0;
152 this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0]; // probability;
154 }
156 // Only increment ignore windows if less than the probability cutoff; this forces the model to "cool-off" from a
157 // previous detection and calling ``reset_probabilities`` so it avoids duplicate detections
158 this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
159 }
160 }
161 return true;
162}
163
165 for (auto &prob : this->recent_streaming_probabilities_) {
166 prob = 0;
167 }
168 this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
169}
170
171WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff,
172 size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
173 bool default_enabled, bool internal_only) {
174 this->id_ = id;
175 this->model_start_ = model_start;
176 this->default_probability_cutoff_ = default_probability_cutoff;
177 this->probability_cutoff_ = default_probability_cutoff;
178 this->sliding_window_size_ = sliding_window_average_size;
179 this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0);
180 this->wake_word_ = wake_word;
181 this->tensor_arena_size_ = tensor_arena_size;
183 this->current_stride_step_ = 0;
184 this->internal_only_ = internal_only;
185
187 bool enabled;
188 if (this->pref_.load(&enabled)) {
189 // Use the enabled state loaded from flash
190 this->enabled_ = enabled;
191 } else {
192 // If no state saved, then use the default
193 this->enabled_ = default_enabled;
194 }
195};
196
198 this->enabled_ = true;
199 if (!this->internal_only_) {
200 this->pref_.save(&this->enabled_);
201 }
202}
203
205 this->enabled_ = false;
206 if (!this->internal_only_) {
207 this->pref_.save(&this->enabled_);
208 }
209}
210
212 DetectionEvent detection_event;
213 detection_event.wake_word = &this->wake_word_;
214 detection_event.max_probability = 0;
215 detection_event.average_probability = 0;
216
217 if ((this->ignore_windows_ < 0) || !this->enabled_) {
218 detection_event.detected = false;
219 return detection_event;
220 }
221
222 uint32_t sum = 0;
223 for (auto &prob : this->recent_streaming_probabilities_) {
224 detection_event.max_probability = std::max(detection_event.max_probability, prob);
225 sum += prob;
226 }
227
228 detection_event.average_probability = sum / this->sliding_window_size_;
229 detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_;
230
232 return detection_event;
233}
234
235VADModel::VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size,
236 size_t tensor_arena_size) {
237 this->model_start_ = model_start;
238 this->default_probability_cutoff_ = default_probability_cutoff;
239 this->probability_cutoff_ = default_probability_cutoff;
240 this->sliding_window_size_ = sliding_window_size;
241 this->recent_streaming_probabilities_.resize(sliding_window_size, 0);
242 this->tensor_arena_size_ = tensor_arena_size;
244}
245
247 DetectionEvent detection_event;
248 detection_event.max_probability = 0;
249 detection_event.average_probability = 0;
250
251 if (!this->enabled_) {
252 // We disabled the VAD model for some reason... so we shouldn't block wake words from being detected
253 detection_event.detected = true;
254 return detection_event;
255 }
256
257 uint32_t sum = 0;
258 for (auto &prob : this->recent_streaming_probabilities_) {
259 detection_event.max_probability = std::max(detection_event.max_probability, prob);
260 sum += prob;
261 }
262
263 detection_event.average_probability = sum / this->sliding_window_size_;
264 detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_);
265
266 return detection_event;
267}
268
269bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) {
270 if (op_resolver.AddCallOnce() != kTfLiteOk)
271 return false;
272 if (op_resolver.AddVarHandle() != kTfLiteOk)
273 return false;
274 if (op_resolver.AddReshape() != kTfLiteOk)
275 return false;
276 if (op_resolver.AddReadVariable() != kTfLiteOk)
277 return false;
278 if (op_resolver.AddStridedSlice() != kTfLiteOk)
279 return false;
280 if (op_resolver.AddConcatenation() != kTfLiteOk)
281 return false;
282 if (op_resolver.AddAssignVariable() != kTfLiteOk)
283 return false;
284 if (op_resolver.AddConv2D() != kTfLiteOk)
285 return false;
286 if (op_resolver.AddMul() != kTfLiteOk)
287 return false;
288 if (op_resolver.AddAdd() != kTfLiteOk)
289 return false;
290 if (op_resolver.AddMean() != kTfLiteOk)
291 return false;
292 if (op_resolver.AddFullyConnected() != kTfLiteOk)
293 return false;
294 if (op_resolver.AddLogistic() != kTfLiteOk)
295 return false;
296 if (op_resolver.AddQuantize() != kTfLiteOk)
297 return false;
298 if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
299 return false;
300 if (op_resolver.AddAveragePool2D() != kTfLiteOk)
301 return false;
302 if (op_resolver.AddMaxPool2D() != kTfLiteOk)
303 return false;
304 if (op_resolver.AddPad() != kTfLiteOk)
305 return false;
306 if (op_resolver.AddPack() != kTfLiteOk)
307 return false;
308 if (op_resolver.AddSplitV() != kTfLiteOk)
309 return false;
310
311 return true;
312}
313
314} // namespace micro_wake_word
315} // namespace esphome
316
317#endif
bool save(const T *src)
Definition preferences.h:21
virtual ESPPreferenceObject make_preference(size_t length, uint32_t type, bool in_flash)=0
An STL allocator that uses SPI or internal RAM.
Definition helpers.h:1899
void deallocate(T *p, size_t n)
Definition helpers.h:1954
T * allocate(size_t n)
Definition helpers.h:1916
bool load_model_()
Allocates tensor and variable arenas and sets up the model interpreter.
std::unique_ptr< tflite::MicroInterpreter > interpreter_
tflite::MicroMutableOpResolver< 20 > streaming_op_resolver_
bool register_streaming_ops_(tflite::MicroMutableOpResolver< 20 > &op_resolver)
Returns true if all of the streaming model's TensorFlow operations were registered successfully.
void reset_probabilities()
Sets all recent_streaming_probabilities to 0 and resets the ignore window count.
std::vector< uint8_t > recent_streaming_probabilities_
tflite::MicroResourceVariables * mrv_
bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE])
void unload_model()
Destroys the TFLite interpreter and frees the tensor and variable arenas' memory.
DetectionEvent determine_detected() override
Checks for voice activity by comparing the mean probability in the sliding window with the probability cutoff.
VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void enable() override
Enable the model and, unless it is internal-only, save the enabled state to flash. The next perform_streaming_inference call will load it.
DetectionEvent determine_detected() override
Checks for the wake word by comparing the mean probability in the sliding window with the probability...
WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, bool default_enabled, bool internal_only)
Constructs a wake word model object.
void disable() override
Disable the model and, unless it is internal-only, save the disabled state to flash. The next perform_streaming_inference call will unload it.
uint16_t id
Providing packet encoding functions for exchanging data with a remote host.
Definition a01nyub.cpp:7
ESPPreferences * global_preferences
uint32_t fnv1_hash(const char *str)
Calculate a FNV-1 hash of str.
Definition helpers.cpp:148
static void uint32_t