ESPHome 2025.5.0
streaming_model.cpp
#include "streaming_model.h"

#ifdef USE_ESP_IDF

#include <algorithm>
#include <cstring>

#include "esphome/core/helpers.h"
#include "esphome/core/log.h"

static const char *const TAG = "micro_wake_word";

namespace esphome {
namespace micro_wake_word {

void WakeWordModel::log_model_config() {
  ESP_LOGCONFIG(TAG, "  - Wake Word: %s", this->wake_word_.c_str());
  ESP_LOGCONFIG(TAG, "    Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f);
  ESP_LOGCONFIG(TAG, "    Sliding window size: %zu", this->sliding_window_size_);
}

void VADModel::log_model_config() {
  ESP_LOGCONFIG(TAG, "  - VAD Model");
  ESP_LOGCONFIG(TAG, "    Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f);
  ESP_LOGCONFIG(TAG, "    Sliding window size: %zu", this->sliding_window_size_);
}

/// Allocates the tensor and variable arenas and sets up the model interpreter.
bool StreamingModel::load_model_() {
  RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE);

  if (this->tensor_arena_ == nullptr) {
    this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_);
    if (this->tensor_arena_ == nullptr) {
      ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena.");
      return false;
    }
  }

  if (this->var_arena_ == nullptr) {
    this->var_arena_ = arena_allocator.allocate(STREAMING_MODEL_VARIABLE_ARENA_SIZE);
    if (this->var_arena_ == nullptr) {
      ESP_LOGE(TAG, "Could not allocate the streaming model's variable tensor arena.");
      return false;
    }
    this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
    // Up to 20 resource variables are available for the streaming layers' internal states
    this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20);
  }

  const tflite::Model *model = tflite::GetModel(this->model_start_);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    ESP_LOGE(TAG, "Streaming model's schema is not supported");
    return false;
  }

  if (this->interpreter_ == nullptr) {
    this->interpreter_ = make_unique<tflite::MicroInterpreter>(model, this->streaming_op_resolver_,
                                                               this->tensor_arena_, this->tensor_arena_size_,
                                                               this->mrv_);
    if (this->interpreter_->AllocateTensors() != kTfLiteOk) {
      ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
      return false;
    }

    // Verify the input tensor matches the expected values.
    // Dimension 1 holds the first layer's stride, so it may vary and is not checked.
    TfLiteTensor *input = this->interpreter_->input(0);
    if ((input->dims->size != 3) || (input->dims->data[0] != 1) ||
        (input->dims->data[2] != PREPROCESSOR_FEATURE_SIZE)) {
      ESP_LOGE(TAG, "Streaming model tensor input has improper dimensions.");
      return false;
    }

    if (input->type != kTfLiteInt8) {
      ESP_LOGE(TAG, "Streaming model tensor input is not int8.");
      return false;
    }

    // Verify the output tensor matches the expected values.
    TfLiteTensor *output = this->interpreter_->output(0);
    if ((output->dims->size != 2) || (output->dims->data[0] != 1) || (output->dims->data[1] != 1)) {
      ESP_LOGE(TAG, "Streaming model tensor output dimension is not 1x1.");
      return false;
    }

    if (output->type != kTfLiteUInt8) {
      ESP_LOGE(TAG, "Streaming model tensor output is not uint8.");
      return false;
    }
  }

  this->loaded_ = true;
  this->reset_probabilities();
  return true;
}

/// Destroys the TFLite interpreter and frees the tensor and variable arenas' memory.
void StreamingModel::unload_model() {
  this->interpreter_.reset();

  RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE);

  if (this->tensor_arena_ != nullptr) {
    arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
    this->tensor_arena_ = nullptr;
  }

  if (this->var_arena_ != nullptr) {
    arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
    this->var_arena_ = nullptr;
  }

  this->loaded_ = false;
}

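// Streams one feature slice into the model. The model is lazily loaded or
// unloaded based on its enabled flag; slices accumulate in the input tensor
// until a full stride is buffered, then the interpreter runs and the quantized
// probability is pushed into the sliding window's ring buffer.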
bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
  if (this->enabled_ && !this->loaded_) {
    // Model is enabled but isn't loaded
    if (!this->load_model_()) {
      return false;
    }
  }

  if (!this->enabled_ && this->loaded_) {
    // Model is disabled but still loaded
    this->unload_model();
    return true;
  }

  if (this->loaded_) {
    TfLiteTensor *input = this->interpreter_->input(0);

    // Dimension 1 of the input tensor is the first layer's stride
    uint8_t stride = input->dims->data[1];
    this->current_stride_step_ = this->current_stride_step_ % stride;

    // Copy the newest feature slice into its slot in the input tensor
    std::memmove(
        (int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_,
        features, PREPROCESSOR_FEATURE_SIZE);
    ++this->current_stride_step_;

    // Only invoke the interpreter once a full stride of feature slices is buffered
    if (this->current_stride_step_ >= stride) {
      TfLiteStatus invoke_status = this->interpreter_->Invoke();
      if (invoke_status != kTfLiteOk) {
        ESP_LOGW(TAG, "Streaming interpreter invoke failed");
        return false;
      }

      TfLiteTensor *output = this->interpreter_->output(0);

      // Store the quantized probability in the sliding window's ring buffer
      ++this->last_n_index_;
      if (this->last_n_index_ == this->sliding_window_size_)
        this->last_n_index_ = 0;
      this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0];
    }
    this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
  }
  return true;
}

/// Sets all recent_streaming_probabilities to 0 and resets the ignore window count.
void StreamingModel::reset_probabilities() {
  for (auto &prob : this->recent_streaming_probabilities_) {
    prob = 0;
  }
  this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
}
161
162WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff,
163 size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
164 bool default_enabled, bool internal_only) {
165 this->id_ = id;
166 this->model_start_ = model_start;
167 this->default_probability_cutoff_ = default_probability_cutoff;
168 this->probability_cutoff_ = default_probability_cutoff;
169 this->sliding_window_size_ = sliding_window_average_size;
170 this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0);
171 this->wake_word_ = wake_word;
172 this->tensor_arena_size_ = tensor_arena_size;
174 this->current_stride_step_ = 0;
175 this->internal_only_ = internal_only;
176
178 bool enabled;
179 if (this->pref_.load(&enabled)) {
180 // Use the enabled state loaded from flash
181 this->enabled_ = enabled;
182 } else {
183 // If no state saved, then use the default
184 this->enabled_ = default_enabled;
185 }
186};

/// Enables the model. The next perform_streaming_inference call will load it.
void WakeWordModel::enable() {
  this->enabled_ = true;
  if (!this->internal_only_) {
    this->pref_.save(&this->enabled_);
  }
}

/// Disables the model. The next perform_streaming_inference call will unload it.
void WakeWordModel::disable() {
  this->enabled_ = false;
  if (!this->internal_only_) {
    this->pref_.save(&this->enabled_);
  }
}

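// Probabilities and the cutoff are quantized to uint8, so e.g. a 0.8 cutoff is
// stored as 204. With a sliding window of 5 slices, detection requires
// sum > 204 * 5 = 1020; i.e., the window's mean probability must exceed the cutoff.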
DetectionEvent WakeWordModel::determine_detected() {
  DetectionEvent detection_event;
  detection_event.wake_word = &this->wake_word_;
  detection_event.max_probability = 0;
  detection_event.average_probability = 0;

  // Detections are blocked until enough slices have been processed after a reset
  if ((this->ignore_windows_ < 0) || !this->enabled_) {
    detection_event.detected = false;
    return detection_event;
  }

  uint32_t sum = 0;
  for (auto &prob : this->recent_streaming_probabilities_) {
    detection_event.max_probability = std::max(detection_event.max_probability, prob);
    sum += prob;
  }

  detection_event.average_probability = sum / this->sliding_window_size_;
  detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_;

  return detection_event;
}

VADModel::VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size,
                   size_t tensor_arena_size) {
  this->model_start_ = model_start;
  this->default_probability_cutoff_ = default_probability_cutoff;
  this->probability_cutoff_ = default_probability_cutoff;
  this->sliding_window_size_ = sliding_window_size;
  this->recent_streaming_probabilities_.resize(sliding_window_size, 0);
  this->tensor_arena_size_ = tensor_arena_size;
  this->register_streaming_ops_(this->streaming_op_resolver_);
}

DetectionEvent VADModel::determine_detected() {
  DetectionEvent detection_event;
  detection_event.max_probability = 0;
  detection_event.average_probability = 0;

  if (!this->enabled_) {
    // The VAD model is disabled, so it shouldn't block wake words from being detected
    detection_event.detected = true;
    return detection_event;
  }

  uint32_t sum = 0;
  for (auto &prob : this->recent_streaming_probabilities_) {
    detection_event.max_probability = std::max(detection_event.max_probability, prob);
    sum += prob;
  }

  detection_event.average_probability = sum / this->sliding_window_size_;
  detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_);

  return detection_event;
}

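// Registers every TensorFlow Lite Micro op the streaming models use. The
// resolver's template capacity of 20 matches the 20 ops added below.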
bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) {
  if (op_resolver.AddCallOnce() != kTfLiteOk)
    return false;
  if (op_resolver.AddVarHandle() != kTfLiteOk)
    return false;
  if (op_resolver.AddReshape() != kTfLiteOk)
    return false;
  if (op_resolver.AddReadVariable() != kTfLiteOk)
    return false;
  if (op_resolver.AddStridedSlice() != kTfLiteOk)
    return false;
  if (op_resolver.AddConcatenation() != kTfLiteOk)
    return false;
  if (op_resolver.AddAssignVariable() != kTfLiteOk)
    return false;
  if (op_resolver.AddConv2D() != kTfLiteOk)
    return false;
  if (op_resolver.AddMul() != kTfLiteOk)
    return false;
  if (op_resolver.AddAdd() != kTfLiteOk)
    return false;
  if (op_resolver.AddMean() != kTfLiteOk)
    return false;
  if (op_resolver.AddFullyConnected() != kTfLiteOk)
    return false;
  if (op_resolver.AddLogistic() != kTfLiteOk)
    return false;
  if (op_resolver.AddQuantize() != kTfLiteOk)
    return false;
  if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
    return false;
  if (op_resolver.AddAveragePool2D() != kTfLiteOk)
    return false;
  if (op_resolver.AddMaxPool2D() != kTfLiteOk)
    return false;
  if (op_resolver.AddPad() != kTfLiteOk)
    return false;
  if (op_resolver.AddPack() != kTfLiteOk)
    return false;
  if (op_resolver.AddSplitV() != kTfLiteOk)
    return false;

  return true;
}

}  // namespace micro_wake_word
}  // namespace esphome

#endif
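
A minimal usage sketch (not part of the file above): one int8 feature slice is streamed into the model per audio frame, and the detection state is polled after each slice. The model pointer, cutoff, window size, arena size, and the get_feature_slice() helper are hypothetical placeholders for illustration, not ESPHome API.

#include "streaming_model.h"

using namespace esphome::micro_wake_word;

bool get_feature_slice(int8_t *out);  // hypothetical: fills one feature slice, returns false when the stream ends

void example_detection_loop(const uint8_t *model_start) {  // hypothetical model flatbuffer pointer
  // Cutoff 204 ~ 0.8 after uint8 quantization; 5-slice window; 45 kB arena (illustrative values)
  WakeWordModel model("okay_nabu", model_start, 204, 5, "okay nabu", 45000, true, false);
  model.enable();

  int8_t features[PREPROCESSOR_FEATURE_SIZE];
  while (get_feature_slice(features)) {
    if (!model.perform_streaming_inference(features))
      break;  // arena allocation or interpreter invoke failed
    DetectionEvent event = model.determine_detected();
    if (event.detected) {
      // Wake word heard; event.average_probability holds the quantized window mean
    }
  }
}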