ESPHome 2025.6.3
Loading...
Searching...
No Matches
streaming_model.cpp
Go to the documentation of this file.
1#include "streaming_model.h"
2
3#ifdef USE_ESP_IDF
4
6#include "esphome/core/log.h"
7
8static const char *const TAG = "micro_wake_word";
9
10namespace esphome {
11namespace micro_wake_word {
12
14 ESP_LOGCONFIG(TAG,
15 " - Wake Word: %s\n"
16 " Probability cutoff: %.2f\n"
17 " Sliding window size: %d",
18 this->wake_word_.c_str(), this->probability_cutoff_ / 255.0f, this->sliding_window_size_);
19}
20
22 ESP_LOGCONFIG(TAG,
23 " - VAD Model\n"
24 " Probability cutoff: %.2f\n"
25 " Sliding window size: %d",
26 this->probability_cutoff_ / 255.0f, this->sliding_window_size_);
27}
28
31
32 if (this->tensor_arena_ == nullptr) {
33 this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_);
34 if (this->tensor_arena_ == nullptr) {
35 ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena.");
36 return false;
37 }
38 }
39
40 if (this->var_arena_ == nullptr) {
41 this->var_arena_ = arena_allocator.allocate(STREAMING_MODEL_VARIABLE_ARENA_SIZE);
42 if (this->var_arena_ == nullptr) {
43 ESP_LOGE(TAG, "Could not allocate the streaming model's variable tensor arena.");
44 return false;
45 }
46 this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
47 this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20);
48 }
49
50 const tflite::Model *model = tflite::GetModel(this->model_start_);
51 if (model->version() != TFLITE_SCHEMA_VERSION) {
52 ESP_LOGE(TAG, "Streaming model's schema is not supported");
53 return false;
54 }
55
56 if (this->interpreter_ == nullptr) {
57 this->interpreter_ =
59 this->tensor_arena_, this->tensor_arena_size_, this->mrv_);
60 if (this->interpreter_->AllocateTensors() != kTfLiteOk) {
61 ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
62 return false;
63 }
64
65 // Verify input tensor matches expected values
66 // Dimension 3 will represent the first layer stride, so skip it may vary
67 TfLiteTensor *input = this->interpreter_->input(0);
68 if ((input->dims->size != 3) || (input->dims->data[0] != 1) ||
69 (input->dims->data[2] != PREPROCESSOR_FEATURE_SIZE)) {
70 ESP_LOGE(TAG, "Streaming model tensor input dimensions has improper dimensions.");
71 return false;
72 }
73
74 if (input->type != kTfLiteInt8) {
75 ESP_LOGE(TAG, "Streaming model tensor input is not int8.");
76 return false;
77 }
78
79 // Verify output tensor matches expected values
80 TfLiteTensor *output = this->interpreter_->output(0);
81 if ((output->dims->size != 2) || (output->dims->data[0] != 1) || (output->dims->data[1] != 1)) {
82 ESP_LOGE(TAG, "Streaming model tensor output dimension is not 1x1.");
83 }
84
85 if (output->type != kTfLiteUInt8) {
86 ESP_LOGE(TAG, "Streaming model tensor output is not uint8.");
87 return false;
88 }
89 }
90
91 this->loaded_ = true;
92 this->reset_probabilities();
93 return true;
94}
95
97 this->interpreter_.reset();
98
100
101 if (this->tensor_arena_ != nullptr) {
102 arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
103 this->tensor_arena_ = nullptr;
104 }
105
106 if (this->var_arena_ != nullptr) {
107 arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
108 this->var_arena_ = nullptr;
109 }
110
111 this->loaded_ = false;
112}
113
114bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
115 if (this->enabled_ && !this->loaded_) {
116 // Model is enabled but isn't loaded
117 if (!this->load_model_()) {
118 return false;
119 }
120 }
121
122 if (!this->enabled_ && this->loaded_) {
123 // Model is disabled but still loaded
124 this->unload_model();
125 return true;
126 }
127
128 if (this->loaded_) {
129 TfLiteTensor *input = this->interpreter_->input(0);
130
131 uint8_t stride = this->interpreter_->input(0)->dims->data[1];
132 this->current_stride_step_ = this->current_stride_step_ % stride;
133
134 std::memmove(
135 (int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_,
136 features, PREPROCESSOR_FEATURE_SIZE);
137 ++this->current_stride_step_;
138
139 if (this->current_stride_step_ >= stride) {
140 TfLiteStatus invoke_status = this->interpreter_->Invoke();
141 if (invoke_status != kTfLiteOk) {
142 ESP_LOGW(TAG, "Streaming interpreter invoke failed");
143 return false;
144 }
145
146 TfLiteTensor *output = this->interpreter_->output(0);
147
148 ++this->last_n_index_;
149 if (this->last_n_index_ == this->sliding_window_size_)
150 this->last_n_index_ = 0;
151 this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0]; // probability;
153 }
155 // Only increment ignore windows if less than the probability cutoff; this forces the model to "cool-off" from a
156 // previous detection and calling ``reset_probabilities`` so it avoids duplicate detections
157 this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
158 }
159 }
160 return true;
161}
162
164 for (auto &prob : this->recent_streaming_probabilities_) {
165 prob = 0;
166 }
167 this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
168}
169
170WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff,
171 size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
172 bool default_enabled, bool internal_only) {
173 this->id_ = id;
174 this->model_start_ = model_start;
175 this->default_probability_cutoff_ = default_probability_cutoff;
176 this->probability_cutoff_ = default_probability_cutoff;
177 this->sliding_window_size_ = sliding_window_average_size;
178 this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0);
179 this->wake_word_ = wake_word;
180 this->tensor_arena_size_ = tensor_arena_size;
182 this->current_stride_step_ = 0;
183 this->internal_only_ = internal_only;
184
186 bool enabled;
187 if (this->pref_.load(&enabled)) {
188 // Use the enabled state loaded from flash
189 this->enabled_ = enabled;
190 } else {
191 // If no state saved, then use the default
192 this->enabled_ = default_enabled;
193 }
194};
195
197 this->enabled_ = true;
198 if (!this->internal_only_) {
199 this->pref_.save(&this->enabled_);
200 }
201}
202
204 this->enabled_ = false;
205 if (!this->internal_only_) {
206 this->pref_.save(&this->enabled_);
207 }
208}
209
211 DetectionEvent detection_event;
212 detection_event.wake_word = &this->wake_word_;
213 detection_event.max_probability = 0;
214 detection_event.average_probability = 0;
215
216 if ((this->ignore_windows_ < 0) || !this->enabled_) {
217 detection_event.detected = false;
218 return detection_event;
219 }
220
221 uint32_t sum = 0;
222 for (auto &prob : this->recent_streaming_probabilities_) {
223 detection_event.max_probability = std::max(detection_event.max_probability, prob);
224 sum += prob;
225 }
226
227 detection_event.average_probability = sum / this->sliding_window_size_;
228 detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_;
229
231 return detection_event;
232}
233
VADModel::VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size,
                   size_t tensor_arena_size) {
  // Constructs the voice-activity-detection model: records where the TFLite model
  // lives in flash, sets the detection threshold, and sizes the sliding window.
  this->model_start_ = model_start;
  this->default_probability_cutoff_ = default_probability_cutoff;
  this->probability_cutoff_ = default_probability_cutoff;
  this->sliding_window_size_ = sliding_window_size;
  this->recent_streaming_probabilities_.resize(sliding_window_size, 0);  // start with all-zero probabilities
  this->tensor_arena_size_ = tensor_arena_size;
  // NOTE(review): one statement here (original line 242) was stripped by the
  // documentation extraction — confirm against upstream which member it initialized.
}
244
246 DetectionEvent detection_event;
247 detection_event.max_probability = 0;
248 detection_event.average_probability = 0;
249
250 if (!this->enabled_) {
251 // We disabled the VAD model for some reason... so we shouldn't block wake words from being detected
252 detection_event.detected = true;
253 return detection_event;
254 }
255
256 uint32_t sum = 0;
257 for (auto &prob : this->recent_streaming_probabilities_) {
258 detection_event.max_probability = std::max(detection_event.max_probability, prob);
259 sum += prob;
260 }
261
262 detection_event.average_probability = sum / this->sliding_window_size_;
263 detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_);
264
265 return detection_event;
266}
267
268bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) {
269 if (op_resolver.AddCallOnce() != kTfLiteOk)
270 return false;
271 if (op_resolver.AddVarHandle() != kTfLiteOk)
272 return false;
273 if (op_resolver.AddReshape() != kTfLiteOk)
274 return false;
275 if (op_resolver.AddReadVariable() != kTfLiteOk)
276 return false;
277 if (op_resolver.AddStridedSlice() != kTfLiteOk)
278 return false;
279 if (op_resolver.AddConcatenation() != kTfLiteOk)
280 return false;
281 if (op_resolver.AddAssignVariable() != kTfLiteOk)
282 return false;
283 if (op_resolver.AddConv2D() != kTfLiteOk)
284 return false;
285 if (op_resolver.AddMul() != kTfLiteOk)
286 return false;
287 if (op_resolver.AddAdd() != kTfLiteOk)
288 return false;
289 if (op_resolver.AddMean() != kTfLiteOk)
290 return false;
291 if (op_resolver.AddFullyConnected() != kTfLiteOk)
292 return false;
293 if (op_resolver.AddLogistic() != kTfLiteOk)
294 return false;
295 if (op_resolver.AddQuantize() != kTfLiteOk)
296 return false;
297 if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
298 return false;
299 if (op_resolver.AddAveragePool2D() != kTfLiteOk)
300 return false;
301 if (op_resolver.AddMaxPool2D() != kTfLiteOk)
302 return false;
303 if (op_resolver.AddPad() != kTfLiteOk)
304 return false;
305 if (op_resolver.AddPack() != kTfLiteOk)
306 return false;
307 if (op_resolver.AddSplitV() != kTfLiteOk)
308 return false;
309
310 return true;
311}
312
313} // namespace micro_wake_word
314} // namespace esphome
315
316#endif
bool save(const T *src)
Definition preferences.h:21
virtual ESPPreferenceObject make_preference(size_t length, uint32_t type, bool in_flash)=0
An STL allocator that uses SPI or internal RAM.
Definition helpers.h:684
void deallocate(T *p, size_t n)
Definition helpers.h:742
T * allocate(size_t n)
Definition helpers.h:704
bool load_model_()
Allocates tensor and variable arenas and sets up the model interpreter.
std::unique_ptr< tflite::MicroInterpreter > interpreter_
tflite::MicroMutableOpResolver< 20 > streaming_op_resolver_
bool register_streaming_ops_(tflite::MicroMutableOpResolver< 20 > &op_resolver)
Returns true if successfully registered the streaming model's TensorFlow operations.
void reset_probabilities()
Sets all recent_streaming_probabilities to 0 and resets the ignore window count.
std::vector< uint8_t > recent_streaming_probabilities_
tflite::MicroResourceVariables * mrv_
bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE])
void unload_model()
Destroys the TFLite interpreter and frees the tensor and variable arenas' memory.
DetectionEvent determine_detected() override
Checks for voice activity by comparing the max probability in the sliding window with the probability...
VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void enable() override
Enable the model and save to flash. The next performing_streaming_inference call will load it.
DetectionEvent determine_detected() override
Checks for the wake word by comparing the mean probability in the sliding window with the probability...
WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, bool default_enabled, bool internal_only)
Constructs a wake word model object.
void disable() override
Disable the model and save to flash. The next performing_streaming_inference call will unload it.
Providing packet encoding functions for exchanging data with a remote host.
Definition a01nyub.cpp:7
uint32_t fnv1_hash(const std::string &str)
Calculate a FNV-1 hash of str.
Definition helpers.cpp:186
ESPPreferences * global_preferences
std::unique_ptr< T > make_unique(Args &&...args)
Definition helpers.h:86
T id(T value)
Helper function to make id(var) known from lambdas work in custom components.
Definition helpers.h:799